vulkan: Add bfloat16 support (llama/12554)
* vulkan: Add bfloat16 support
This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.
It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.
The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.
* vulkan: Support bf16 tensors without the bf16 extension or coopmat support
Compile a variant of the scalar mul_mm shader that will promote the bf16
values to float, and use that when either the bf16 extension or coopmat support
is unavailable.
* vulkan: bfloat16 fixes (really works without bfloat16 support now)
* vulkan: fix spirv-val failure and reenable -O
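
Taken together, the changes give a tiered dispatch: coopmat bf16 matrix multiply
when VK_KHR_shader_bfloat16 and a capable glslc are available, the scalar mul_mm
variant that promotes bf16 to fp32 otherwise, and plain fp32 promotion in the
matrix-vector and copy paths. A rough C++ sketch of that policy (names here are
illustrative, not the actual ggml-vulkan symbols):

    // Hypothetical selection mirroring the fallback described above; the real
    // code expresses this through the pipeline tables in ggml_vk_load_shaders().
    struct bf16_caps {
        bool coopmat_support;      // VK_KHR_cooperative_matrix available
        bool coopmat_bf16_support; // driver advertises bf16 coopmat shapes
    };

    enum class bf16_path { coopmat, scalar_promote };

    static bf16_path choose_bf16_mul_mm(const bf16_caps & caps) {
        if (caps.coopmat_support && caps.coopmat_bf16_support) {
            return bf16_path::coopmat;    // requires VK_KHR_shader_bfloat16
        }
        return bf16_path::scalar_promote; // mul_mm variant widens bf16 to fp32
    }
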
- ggml/src/ggml-vulkan/CMakeLists.txt +17 -0
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +175 -21
- ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -0
- ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +9 -2
- ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +4 -1
- ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +7 -1
- ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +7 -2
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +1 -1
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +31 -8
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +23 -14
- ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- ggml/src/ggml-vulkan/vulkan-shaders/types.comp +27 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -15
ggml/src/ggml-vulkan/CMakeLists.txt

@@ -71,6 +71,22 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     endif()
 
+    # Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
+        message(STATUS "GL_EXT_bfloat16 not supported by glslc")
+        set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
+    else()
+        message(STATUS "GL_EXT_bfloat16 supported by glslc")
+        set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
+        add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
@@ -142,6 +158,7 @@ if (Vulkan_FOUND)
             -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
             -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
             -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
+            -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
         BUILD_COMMAND ${CMAKE_COMMAND} --build .
         INSTALL_COMMAND ${CMAKE_COMMAND} --install .
         INSTALL_DIR ${CMAKE_BINARY_DIR}
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -51,6 +51,24 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+// remove this once it's more widely available in the SDK
+#if !defined(VK_KHR_shader_bfloat16)
+
+#define VK_KHR_shader_bfloat16 1
+#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
+#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
+#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
+#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
+
+typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
+    VkStructureType sType;
+    void* pNext;
+    VkBool32 shaderBFloat16Type;
+    VkBool32 shaderBFloat16DotProduct;
+    VkBool32 shaderBFloat16CooperativeMatrix;
+} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
+#endif
+
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

@@ -266,8 +284,9 @@ struct vk_device_struct {
     bool subgroup_require_full_support;
 
     bool coopmat_support;
-    bool coopmat_acc_f32_support;
-    bool coopmat_acc_f16_support;
+    bool coopmat_acc_f32_support {};
+    bool coopmat_acc_f16_support {};
+    bool coopmat_bf16_support {};
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;

@@ -293,6 +312,7 @@ struct vk_device_struct {
 
     vk_matmul_pipeline pipeline_matmul_f32 {};
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
+    vk_matmul_pipeline pipeline_matmul_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
 
@@ -301,6 +321,7 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
+    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
    vk_matmul_pipeline2 pipeline_matmul_id_f16;
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 
@@ -333,8 +354,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;

@@ -1811,6 +1832,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     if (!device->pipeline_matmul_id_f32) {
         device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     }
+    if (!device->pipeline_matmul_bf16) {
+        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_bf16) {
+        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
 
     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,

@@ -1920,6 +1947,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)

@@ -1941,6 +1973,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)

@@ -1994,6 +2031,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
+        }
+#endif
 
         if (device->coopmat_acc_f16_support) {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2042,6 +2084,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+        }
+#endif
 
         if (device->coopmat_acc_f16_support) {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2124,6 +2171,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2159,6 +2208,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2211,6 +2262,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2246,6 +2299,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2266,8 +2321,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-#undef CREATE_MM
     }
+    // reusing CREATE_MM from the fp32 path
+    if ((device->coopmat2 || device->coopmat_support)
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        && !device->coopmat_bf16_support
+#endif
+    ) {
+        // use scalar tile sizes
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+
+        l_wg_denoms = {128, 128, 1 };
+        m_wg_denoms = { 64,  64, 1 };
+        s_wg_denoms = { 32,  32, 1 };
+
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+    }
+#undef CREATE_MM
 
     // mul mat vec
 
@@ -2286,6 +2359,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);

@@ -2308,6 +2382,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);

@@ -2331,6 +2406,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);

@@ -2376,6 +2452,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

@@ -2393,6 +2470,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

@@ -2430,10 +2508,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);

@@ -2601,6 +2682,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool coopmat2_support = false;
     device->coopmat_support = false;
     device->integer_dot_product = false;
+    bool bfloat16_support = false;
 
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {

@@ -2631,6 +2713,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
                    !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
            device->integer_dot_product = true;
 #endif
+        } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_BFLOAT16")) {
+            bfloat16_support = true;
         }
     }
 
@@ -2817,6 +2902,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
 #endif
 
+#if defined(VK_KHR_shader_bfloat16)
+    VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
+    bfloat16_features.pNext = nullptr;
+    bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
+    if (bfloat16_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
+        last_struct = (VkBaseOutStructure *)&bfloat16_features;
+        device_extensions.push_back("VK_KHR_shader_bfloat16");
+    }
+#endif
+
     VkPhysicalDeviceMaintenance4Features maint4_features {};
     maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
     if (maintenance4_support) {

@@ -3014,6 +3110,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->coopmat_int_n = prop.NSize;
             device->coopmat_int_k = prop.KSize;
         }
+#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+            prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+            prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+            prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+            (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+        ) {
+            // coopmat sizes not set yet
     }
 
     if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {

@@ -3021,11 +3136,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
         GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
         device->coopmat_support = false;
     }
     }
 
     if (device->coopmat_support) {
         device_extensions.push_back("VK_KHR_cooperative_matrix");
     }
 #endif
     device->name = GGML_VK_NAME + std::to_string(idx);
 
@@ -3482,6 +3605,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_f16_f32.f16acc;

@@ -3553,6 +3679,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:

@@ -3585,6 +3712,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_id_f32;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_id_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_id_f16_f32.f16acc;

@@ -3638,6 +3768,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:

@@ -4373,6 +4504,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cpy_f16_f16;
         }
     }
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_bf16;
+        } else {
+            return ctx->device->pipeline_cpy_f32_bf16;
+        }
+    }
     if (src->type == GGML_TYPE_F32) {
         switch (to) {
             case GGML_TYPE_Q4_0:

@@ -4500,8 +4638,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
     bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;

@@ -4511,25 +4653,25 @@
 
     if (mmp == nullptr) {
         // Fall back to f16 dequant mul mat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ?
         quantize_y = false;
     }
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = !quantize_y && ((src1->type !=
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx,
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ?
     const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ?
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;

@@ -4550,12 +4692,12 @@
     vk_pipeline to_q8_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }

@@ -5055,7 +5197,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
     } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
-               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);

@@ -5123,27 +5265,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ?
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type !=
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx,
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ?
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ?
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;

@@ -5162,12 +5308,12 @@
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }

@@ -9295,6 +9441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (src0_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:

@@ -9330,10 +9477,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (a->ne[3] != b->ne[3]) {
                 return false;
             }
-            if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
                 !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
                 return false;
             }
 
             return true;
         } break;

@@ -9406,6 +9558,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:

@@ -9436,6 +9589,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 switch (src1_type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
| 51 |
|
| 52 |
#include "ggml-vulkan-shaders.hpp"
|
| 53 |
|
| 54 |
+
// remove this once it's more widely available in the SDK
|
| 55 |
+
#if !defined(VK_KHR_shader_bfloat16)
|
| 56 |
+
|
| 57 |
+
#define VK_KHR_shader_bfloat16 1
|
| 58 |
+
#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
|
| 59 |
+
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
|
| 60 |
+
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
|
| 61 |
+
#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
|
| 62 |
+
|
| 63 |
+
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
| 64 |
+
VkStructureType sType;
|
| 65 |
+
void* pNext;
|
| 66 |
+
VkBool32 shaderBFloat16Type;
|
| 67 |
+
VkBool32 shaderBFloat16DotProduct;
|
| 68 |
+
VkBool32 shaderBFloat16CooperativeMatrix;
|
| 69 |
+
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
| 70 |
+
#endif
|
| 71 |
+
|
| 72 |
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
| 73 |
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
| 74 |
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
|
|
|
| 284 |
bool subgroup_require_full_support;
|
| 285 |
|
| 286 |
bool coopmat_support;
|
| 287 |
+
bool coopmat_acc_f32_support {};
|
| 288 |
+
bool coopmat_acc_f16_support {};
|
| 289 |
+
bool coopmat_bf16_support {};
|
| 290 |
uint32_t coopmat_m;
|
| 291 |
uint32_t coopmat_n;
|
| 292 |
uint32_t coopmat_k;
|
|
|
|
| 312 |
|
| 313 |
vk_matmul_pipeline pipeline_matmul_f32 {};
|
| 314 |
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
| 315 |
+
vk_matmul_pipeline pipeline_matmul_bf16 {};
|
| 316 |
vk_matmul_pipeline2 pipeline_matmul_f16;
|
| 317 |
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
| 318 |
|
|
|
|
| 321 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
| 322 |
|
| 323 |
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
| 324 |
+
vk_matmul_pipeline pipeline_matmul_id_bf16 {};
|
| 325 |
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
| 326 |
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
| 327 |
|
|
|
|
| 354 |
vk_pipeline pipeline_clamp_f32;
|
| 355 |
vk_pipeline pipeline_pad_f32;
|
| 356 |
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
| 357 |
+
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
|
| 358 |
+
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
|
| 359 |
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
| 360 |
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
| 361 |
vk_pipeline pipeline_norm_f32;
|
|
|
|
| 1832 |
if (!device->pipeline_matmul_id_f32) {
|
| 1833 |
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1834 |
}
|
| 1835 |
+
if (!device->pipeline_matmul_bf16) {
|
| 1836 |
+
device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1837 |
+
}
|
| 1838 |
+
if (!device->pipeline_matmul_id_bf16) {
|
| 1839 |
+
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1840 |
+
}
|
| 1841 |
|
| 1842 |
std::vector<std::future<void>> compiles;
|
| 1843 |
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
|
|
|
| 1947 |
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1948 |
|
| 1949 |
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1950 |
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
| 1951 |
+
if (device->coopmat_bf16_support) {
|
| 1952 |
+
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1953 |
+
}
|
| 1954 |
+
#endif
|
| 1955 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1956 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1957 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
|
|
| 1973 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1974 |
|
| 1975 |
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1976 |
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
| 1977 |
+
if (device->coopmat_bf16_support) {
|
| 1978 |
+
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1979 |
+
}
|
| 1980 |
+
#endif
|
| 1981 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1982 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1983 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
|
|
| 2031 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2032 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2033 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2034 |
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
| 2035 |
+
if (device->coopmat_bf16_support) {
|
| 2036 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
|
| 2037 |
+
}
|
| 2038 |
+
#endif
|
| 2039 |
|
| 2040 |
if (device->coopmat_acc_f16_support) {
|
| 2041 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
|
|
| 2084 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2085 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2086 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2087 |
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
| 2088 |
+
if (device->coopmat_bf16_support) {
|
| 2089 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2090 |
+
}
|
| 2091 |
+
#endif
|
| 2092 |
|
| 2093 |
if (device->coopmat_acc_f16_support) {
|
| 2094 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
| 2171 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2172 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2173 |
|
| 2174 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2175 |
+
|
| 2176 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2177 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2178 |
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
|
|
| 2208 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2209 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2210 |
|
| 2211 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
| 2212 |
+
|
| 2213 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2214 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2215 |
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
| 2262 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2263 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2264 |
|
| 2265 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2266 |
+
|
| 2267 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2268 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2269 |
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
|
|
| 2299 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2300 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2301 |
|
| 2302 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
| 2303 |
+
|
| 2304 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2305 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2306 |
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
| 2321 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2322 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2323 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
| 2324 |
}
|
| 2325 |
+
// reusing CREATE_MM from the fp32 path
|
| 2326 |
+
if ((device->coopmat2 || device->coopmat_support)
|
| 2327 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2328 |
+
&& !device->coopmat_bf16_support
|
| 2329 |
+
#endif
|
| 2330 |
+
) {
|
| 2331 |
+
// use scalar tile sizes
|
| 2332 |
+
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
| 2333 |
+
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
| 2334 |
+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
|
| 2335 |
+
|
| 2336 |
+
l_wg_denoms = {128, 128, 1 };
|
| 2337 |
+
m_wg_denoms = { 64, 64, 1 };
|
| 2338 |
+
s_wg_denoms = { 32, 32, 1 };
|
| 2339 |
+
|
| 2340 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2341 |
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
| 2342 |
+
}
|
| 2343 |
+
#undef CREATE_MM
|
| 2344 |
|
| 2345 |
// mul mat vec
|
| 2346 |
|
|
|
|
| 2359 |
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
| 2360 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2361 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2362 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2363 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
| 2364 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
| 2365 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
|
|
| 2382 |
|
| 2383 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2384 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2385 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2386 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
| 2387 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
| 2388 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
|
|
| 2406 |
|
| 2407 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
| 2408 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
| 2409 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
| 2410 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
| 2411 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
| 2412 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);

     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
...
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
...
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16, "cpy_f32_bf16", cpy_f32_bf16_len, cpy_f32_bf16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16, "contig_cpy_f32_bf16", contig_cpy_f32_bf16_len, contig_cpy_f32_bf16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
...
     bool coopmat2_support = false;
     device->coopmat_support = false;
     device->integer_dot_product = false;
+    bool bfloat16_support = false;

     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
...
                    !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
             device->integer_dot_product = true;
 #endif
+        } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_BFLOAT16")) {
+            bfloat16_support = true;
         }
     }
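
A minimal sketch of the same probe outside ggml, assuming only the standard Vulkan headers (vkEnumerateDeviceExtensionProperties is the real API; device_has_bfloat16 is a hypothetical helper, not a ggml function):

```cpp
#include <vulkan/vulkan.h>
#include <cstdlib>
#include <cstring>
#include <vector>

// Returns true when the physical device advertises VK_KHR_shader_bfloat16
// and the GGML_VK_DISABLE_BFLOAT16 opt-out is not set, mirroring the loop above.
bool device_has_bfloat16(VkPhysicalDevice phys_dev) {
    uint32_t count = 0;
    vkEnumerateDeviceExtensionProperties(phys_dev, nullptr, &count, nullptr);
    std::vector<VkExtensionProperties> props(count);
    vkEnumerateDeviceExtensionProperties(phys_dev, nullptr, &count, props.data());
    for (const auto & p : props) {
        if (strcmp(p.extensionName, "VK_KHR_shader_bfloat16") == 0) {
            return getenv("GGML_VK_DISABLE_BFLOAT16") == nullptr;
        }
    }
    return false;
}
```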
...
     }
 #endif

+#if defined(VK_KHR_shader_bfloat16)
+    VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
+    bfloat16_features.pNext = nullptr;
+    bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
+    if (bfloat16_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
+        last_struct = (VkBaseOutStructure *)&bfloat16_features;
+        device_extensions.push_back("VK_KHR_shader_bfloat16");
+    }
+#endif
+
     VkPhysicalDeviceMaintenance4Features maint4_features {};
     maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
     if (maintenance4_support) {
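
A compact sketch of the pNext-chaining pattern used above (the real code threads last_struct through many optional feature structs; member names are per VK_KHR_shader_bfloat16):

```cpp
#include <vulkan/vulkan.h>

// Hang the bfloat16 feature struct off a features2 chain and let the driver
// fill in the support bits.
bool query_bfloat16_features(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceShaderBfloat16FeaturesKHR bf16 {};
    bf16.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;

    VkPhysicalDeviceFeatures2 feats2 {};
    feats2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    feats2.pNext = &bf16;                       // chain: feats2 -> bf16

    vkGetPhysicalDeviceFeatures2(phys_dev, &feats2);
    return bf16.shaderBFloat16Type == VK_TRUE;  // basic bf16 type support
}
```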
...
             device->coopmat_int_n = prop.NSize;
             device->coopmat_int_k = prop.KSize;
         }
+#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+            prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+            prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+            prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+            (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+        ) {
+            // coopmat sizes not set yet
+            if (device->coopmat_m == 0) {
+                device->coopmat_bf16_support = true;
+                device->coopmat_m = prop.MSize;
+                device->coopmat_n = prop.NSize;
+                device->coopmat_k = prop.KSize;
+            } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+                // Only enable if shape is identical
+                device->coopmat_bf16_support = true;
+            }
+        }
+#endif
     }

     if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
...
         GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
         device->coopmat_support = false;
     }
+    if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
+        device->coopmat_bf16_support = false;
+    }
 }

 if (device->coopmat_support) {
     device_extensions.push_back("VK_KHR_cooperative_matrix");
 }
+#if defined(VK_KHR_shader_bfloat16)
+if (device->coopmat_bf16_support) {
+    device_extensions.push_back("VK_KHR_shader_bfloat16");
+}
+#endif
 #endif
 device->name = GGML_VK_NAME + std::to_string(idx);
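
The property loop above consumes shapes enumerated from the cooperative-matrix extension. A hedged sketch of that query with the C API (function and struct names are the real VK_KHR_cooperative_matrix API; the filtering mirrors the bf16 x bf16 -> fp32 subgroup check above):

```cpp
#include <vulkan/vulkan.h>
#include <vector>

void enumerate_bf16_coopmat_shapes(VkPhysicalDevice phys_dev) {
    uint32_t n = 0;
    vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(phys_dev, &n, nullptr);
    std::vector<VkCooperativeMatrixPropertiesKHR> shapes(n);
    for (auto & s : shapes) s.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
    vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(phys_dev, &n, shapes.data());

    for (const auto & s : shapes) {
        if (s.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
            s.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
            s.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
            s.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
            s.scope == VK_SCOPE_SUBGROUP_KHR) {
            // candidate tile: s.MSize x s.NSize x s.KSize, usable only if it
            // matches the shape already chosen for the f16 coopmat pipelines
        }
    }
}
```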
...
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_f16_f32.f16acc;
...
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
...
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_id_f32;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_id_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
...
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
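
Condensed, the rule these selection functions implement: a dedicated bf16 pipeline is used only for the BF16 x BF16 pairing; everything else falls through to the pre-existing f32/f16 paths. A hedged sketch (pick_matmul_kind is hypothetical; ggml_type comes from ggml.h):

```cpp
#include "ggml.h"

enum class MatmulKind { BF16, LEGACY_F32_F16 };

// Only BF16 x BF16 maps to pipeline_matmul_bf16 / pipeline_matmul_id_bf16;
// BF16 A with any other B type is handled by converting B first (see below).
MatmulKind pick_matmul_kind(ggml_type src0, ggml_type src1) {
    if (src0 == GGML_TYPE_BF16 && src1 == GGML_TYPE_BF16) {
        return MatmulKind::BF16;
    }
    return MatmulKind::LEGACY_F32_F16;
}
```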
...
             return ctx->device->pipeline_cpy_f16_f16;
         }
     }
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_bf16;
+        } else {
+            return ctx->device->pipeline_cpy_f32_bf16;
+        }
+    }
     if (src->type == GGML_TYPE_F32) {
         switch (to) {
             case GGML_TYPE_Q4_0:
...
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
                               !ggml_vk_dim01_contiguous(src1);

+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;

     bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
...
     if (mmp == nullptr) {
         // Fall back to f16 dequant mul mat
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
         quantize_y = false;
     }

     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);

     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }

     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
     const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;

+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));

     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
...
     vk_pipeline to_q8_1 = nullptr;

     if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
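
The key decision in this hunk is the intermediate type: when A is BF16, the conversion target for a non-contiguous or mismatched B becomes BF16 instead of F16, so the BF16 x BF16 pipeline can be used. A standalone sketch of that decision (hypothetical helpers; ggml_tensor/ggml_type come from ggml.h):

```cpp
#include "ggml.h"

// Mirrors: ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
ggml_type intermediate_type(const ggml_tensor * src0) {
    return src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
}

// Mirrors the extra y_non_contig clause: BF16 A with a non-BF16 B forces a
// conversion copy of B, because only a bf16 x bf16 shader exists.
bool b_forces_copy(const ggml_tensor * src0, const ggml_tensor * src1) {
    return src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16;
}
```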
...
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
     } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
+               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
...
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
                               !ggml_vk_dim01_contiguous(src1);

+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;

+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);

     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
+    const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;

     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }

     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

+    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);

     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
...
     vk_pipeline to_fp16_vk_1 = nullptr;

     if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
...
             switch (src0_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
...
             if (a->ne[3] != b->ne[3]) {
                 return false;
             }
+            if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
                 !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
                 return false;
             }
+            if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
+                // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
+                // So don't support this combination for now.
+                return false;
+            }

             return true;
         } break;
...
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
...
             switch (src1_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
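
As a truth table, the new supports_op checks carve out exactly one combination. Sketch (hypothetical helper; not a ggml function):

```cpp
#include "ggml.h"

// BF16 x F32 and BF16 x BF16 work (an F32 B is converted to BF16 first);
// BF16 x F16 is rejected: there is neither a bf16 x f16 shader nor an
// fp16 -> bf16 copy shader yet.
bool mul_mat_bf16_combo_supported(ggml_type a, ggml_type b) {
    if (a == GGML_TYPE_BF16 && b == GGML_TYPE_F16) {
        return false;
    }
    return true;
}
```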
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt:

@@ -12,6 +12,9 @@ endif()
 if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 endif()
+if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+endif()
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp:

@@ -18,7 +18,11 @@ void main() {
     // fast path for when all four iterations are in-bounds
     if (idx + (num_iter-1)*num_threads < p.ne) {
         [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+
+#if defined(DATA_D_BF16)
+            float f = float(data_a[get_aoffset() + idx]);
+            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
             data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
 #else
             data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
                 continue;
             }

-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+#if defined(DATA_D_BF16)
+            float f = float(data_a[get_aoffset() + idx]);
+            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
             data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
 #else
             data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
ggml/src/ggml-vulkan/vulkan-shaders/copy.comp:

@@ -12,7 +12,10 @@ void main() {
         return;
     }

-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+#if defined(DATA_D_BF16)
+    float f = float(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
     data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
 #else
     data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
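
What the DATA_D_BF16 branch does per element, as a self-contained CPU-side sketch (the bit formula is the same fp32_to_bf16 added to types.comp further down; helper names here are hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Round-to-nearest-even narrowing of fp32 to the 16 bf16 bits.
static uint16_t fp32_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
    return (uint16_t) u;
}

// Equivalent of cpy_f32_bf16 over a contiguous range.
void copy_f32_to_bf16(const float * src, uint16_t * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = fp32_to_bf16_bits(src[i]);
    }
}
```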
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp:

@@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 }
 #endif

+#if defined(DATA_A_BF16)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif

-#if defined(DATA_A_F32) || defined(DATA_A_F16)
+#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(0, 0);
 }
ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp:

@@ -20,9 +20,14 @@ void main() {
     const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
     const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

+#if defined(DATA_A_BF16)
+    FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+#else
+    FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+#endif
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+    data_d[d_offset + i00] = D_TYPE(v);
 #else
-    data_d[d_offset + i00] = data_a[a_offset + i00];
+    data_d[d_offset + i00] = D_TYPE(v);
 #endif
 }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp:

@@ -6,7 +6,7 @@

 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
 #define K_PER_ITER 8
 #else
 #define K_PER_ITER 2
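
The bf16 matrix-vector path needs no extension because each element is widened to fp32 before the multiply-accumulate, which is also why bf16 joins f32/f16 in the K_PER_ITER 2 bucket. A self-contained CPU-side sketch of that inner loop (helper names are hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// bf16 is the top half of an fp32; widening is just a 16-bit shift.
static float bf16_to_fp32(uint16_t u) {
    uint32_t bits = (uint32_t) u << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Promote-then-accumulate dot product, two elements per iteration (K_PER_ITER 2).
float dot_bf16_f32(const uint16_t * a, const float * b, size_t k) {
    float acc = 0.0f;
    size_t i = 0;
    for (; i + 1 < k; i += 2) {
        acc += bf16_to_fp32(a[i])     * b[i];
        acc += bf16_to_fp32(a[i + 1]) * b[i + 1];
    }
    if (i < k) {
        acc += bf16_to_fp32(a[i]) * b[i];   // odd tail
    }
    return acc;
}
```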
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp:

@@ -10,6 +10,10 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif

+#if defined(DATA_A_BF16) && defined(COOPMAT)
+#extension GL_EXT_bfloat16 : enable
+#endif
+
 #ifdef COOPMAT
 #extension GL_KHR_cooperative_matrix : enable
 #extension GL_KHR_memory_scope_semantics : enable
@@ -29,6 +33,10 @@
 #define LOAD_VEC_B 1
 #endif

+#if !defined(TO_FLOAT_TYPE)
+#define TO_FLOAT_TYPE FLOAT_TYPE
+#endif
+
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -202,8 +210,8 @@ void main() {
 #endif

 #ifdef COOPMAT
-    coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
-    coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
     coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

     [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +256,21 @@ void main() {
             buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
         }
 #endif
+#elif defined(DATA_A_BF16)
+#if LOAD_VEC_A == 4
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+        buf_a[buf_idx    ] = TO_FLOAT_TYPE(data_a[idx].x);
+        buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
+        buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
+        buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
+#else
+        if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
+            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+        } else {
+            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
+        }
+#endif
 #elif defined(DATA_A_Q4_0)
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
         const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
@@ -695,13 +718,13 @@ void main() {
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
         const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
-        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
-        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
-        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
+        buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
+        buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
+        buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
+        buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
 #elif !MUL_MAT_ID
         if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
         } else {
             buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
         }
@@ -709,7 +732,7 @@ void main() {
             const uint row_i = ic * BN + loadc_b + l;
             if (row_i < _ne1) {
                 const u16vec2 row_idx = row_ids[row_i];
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
             } else {
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
             }
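
The LOAD_VEC_A == 4 path above pulls four packed bf16 lanes per load and widens each through TO_FLOAT_TYPE, which expands to bf16_to_fp32 on the scalar path and uintBitsToBFloat16EXT on the coopmat path. A self-contained CPU-side sketch of the scalar variant (u16vec4 here is a stand-in struct for the GLSL type):

```cpp
#include <cstdint>
#include <cstring>

struct u16vec4 { uint16_t x, y, z, w; };

static float bf16_to_fp32(uint16_t u) {
    uint32_t bits = (uint32_t) u << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// One vectorized A-tile load: four bf16 lanes into four fp32 shared-memory slots.
void load_a_vec4(const u16vec4 & a, float buf_a[4]) {
    buf_a[0] = bf16_to_fp32(a.x);
    buf_a[1] = bf16_to_fp32(a.y);
    buf_a[2] = bf16_to_fp32(a.z);
    buf_a[3] = bf16_to_fp32(a.w);
}
```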
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp:

@@ -14,6 +14,9 @@
 #extension GL_EXT_buffer_reference : enable
 #extension GL_KHR_shader_subgroup_ballot : enable
 #extension GL_KHR_shader_subgroup_vote : enable
+#ifdef DATA_A_BF16
+#extension GL_EXT_bfloat16 : enable
+#endif

 #include "types.comp"

@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #define store_scales(a)
 #endif

+#if defined(DATA_A_BF16)
+#define MAT_TYPE bfloat16_t
+#else
+#define MAT_TYPE FLOAT_TYPE
+#endif
+
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};

@@ -271,8 +280,8 @@ void main() {

             // Manually partial unroll
             [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
             store_scales(tid);
         }
         while (block_k < end_k) {
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {

             // Manually partial unroll
             [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
             store_scales(tid);
         }
         while (block_k < end_k) {
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {

             // Manually partial unroll
             [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
             store_scales(tid);
         }
         while (block_k < end_k) {
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
         fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
     }

-    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp (new file):

@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_EXT_bfloat16 : require
+
+void main()
+{
+}
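
This seven-line shader exists purely as a compile probe: it does nothing at runtime, and the build (via CMake's execute_process) only cares whether glslc accepts GL_EXT_bfloat16. A hedged C++ equivalent of the same idea, assuming a POSIX shell and a glslc on PATH:

```cpp
#include <cstdio>
#include <cstdlib>

// Returns true when glslc compiles the probe shader, i.e. its glslang is new
// enough to know GL_EXT_bfloat16. Illustrative sketch only; the flags match
// the CMake invocation, the /dev/null output path is an assumption.
bool glslc_supports_bfloat16(const char * probe_path) {
    char cmd[512];
    std::snprintf(cmd, sizeof(cmd),
        "glslc -o /dev/null -fshader-stage=compute --target-env=vulkan1.3 %s 2>/dev/null",
        probe_path);
    return std::system(cmd) == 0;
}
```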
ggml/src/ggml-vulkan/vulkan-shaders/types.comp:

@@ -33,6 +33,19 @@
 #endif
 #endif

+#if defined(DATA_A_BF16)
+#define QUANT_K 1
+#define QUANT_R 1
+
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
+#define A_TYPE uint16_t
+#elif LOAD_VEC_A == 4
+#define A_TYPE u16vec4
+#elif LOAD_VEC_A == 8
+#error unsupported
+#endif
+#endif
+
 #define QUANT_K_Q4_0 32
 #define QUANT_R_Q4_0 2

@@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
 }
 #endif

+// returns the bfloat value in the low 16b.
+// See ggml_compute_fp32_to_bf16
+uint32_t fp32_to_bf16(float f)
+{
+    uint32_t u = floatBitsToUint(f);
+    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
+    return u;
+}
+
+float bf16_to_fp32(uint32_t u)
+{
+    return uintBitsToFloat(u << 16);
+}
+
 #endif // !defined(GGML_TYPES_COMP)
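
The add of 0x7fff plus the lowest kept bit implements round-to-nearest-even on the 16 discarded mantissa bits. A self-contained C++ version of the same pair with a worked tie case:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t fp32_to_bf16(float f) {          // same formula as the GLSL above
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
    return (uint16_t) u;
}

float bf16_to_fp32(uint16_t u) {
    uint32_t bits = (uint32_t) u << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    // 1.0f = 0x3f800000: the discarded low 16 bits are zero, nothing to round -> 0x3f80.
    std::printf("%04x\n", fp32_to_bf16(1.0f));          // 3f80
    // 1 + 2^-8 = 0x3f808000: the discarded bits are exactly half. The kept LSB
    // ((u >> 16) & 1) is 0, so the tie rounds down to the even value 0x3f80
    // instead of up to 0x3f81.
    std::printf("%04x\n", fp32_to_bf16(1.00390625f));   // 3f80
    std::printf("%g\n", bf16_to_fp32(0x3f80));          // 1
    return 0;
}
```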
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp:

@@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
     "iq3_xxs",
     "iq3_s",
     "iq4_xs",
-    "iq4_nl"
+    "iq4_nl",
+    "bf16",
 };

 namespace {
@@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";

     std::map<std::string, std::string> base_dict = {
-        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
         {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
     };
     std::string shader_name = "matmul";
@@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool

     const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";

+    auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+        if (t == "bf16") {
+            // scalar path promotes to float
+            if (!coopmat && !coopmat2) {
+                return "float";
+            }
+            return "bfloat16_t";
+        }
+        if (coopmat2 || fp16) {
+            return "float16_t";
+        }
+        return "float";
+    };
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+    // bf16
+    {
+        std::string load_vec_a_unaligned = "1";
+        // For aligned matmul loads
+        std::string load_vec_a = coopmat2 ? "1" : "4";
+
+        // scalar path promotes to float
+        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
+
+        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
+#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (!(coopmat || coopmat2))
+#endif
+        {
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        }
+    }

     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
@@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
             load_vec_quant = "4";

+        if (tname == "bf16") {
+            continue;
+        }
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;

         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }

         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
-            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
     }
@@ -393,6 +430,7 @@ void process_shaders() {
         if (tname == "f32") {
             continue;
         }
+        if (tname == "bf16") continue;

         if (tname == "f16") {
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
@@ -417,12 +455,12 @@
         string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));

         // Dequant shaders
-        if (tname != "f16") {
+        if (tname != "f16" && tname != "bf16") {
             string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
         }

         if (!string_ends_with(tname, "_k")) {
-            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+            shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";

             if (tname == "f16") {
                 string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@@ -447,9 +485,11 @@
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});

     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
         string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
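
The FLOAT_TYPE lambda above is the crux of the fallback story: bf16 data keeps bfloat16_t only when a coopmat path (and therefore GL_EXT_bfloat16) is in play; otherwise the scalar shader is compiled with plain float and values are promoted on load. A compact mirror of that truth table (float_type is a hypothetical copy for illustration, not the generator's code):

```cpp
#include <cassert>
#include <string>

std::string float_type(const std::string & t, bool fp16, bool coopmat, bool coopmat2) {
    if (t == "bf16") {
        // scalar path promotes to float; coopmat paths keep real bf16
        return (!coopmat && !coopmat2) ? "float" : "bfloat16_t";
    }
    return (coopmat2 || fp16) ? "float16_t" : "float";
}

int main() {
    assert(float_type("bf16", /*fp16=*/true,  /*coopmat=*/false, /*coopmat2=*/false) == "float");
    assert(float_type("bf16", true,  true,  false) == "bfloat16_t");
    assert(float_type("q4_0", true,  false, false) == "float16_t");
    assert(float_type("q4_0", false, false, false) == "float");
    return 0;
}
```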