jeffbolznv commited on
Commit
b21f8a1
·
1 Parent(s): 710fdcf

vulkan: Add bfloat16 support (llama/12554)

Browse files

* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that will promote the bf16
values to float, and use that when either the bf16 extension or the coopmat
extensions aren't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O

ggml/src/ggml-vulkan/CMakeLists.txt CHANGED
@@ -71,6 +71,22 @@ if (Vulkan_FOUND)
71
  add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
72
  endif()
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
75
  target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
76
 
@@ -142,6 +158,7 @@ if (Vulkan_FOUND)
142
  -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
143
  -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
144
  -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
 
145
  BUILD_COMMAND ${CMAKE_COMMAND} --build .
146
  INSTALL_COMMAND ${CMAKE_COMMAND} --install .
147
  INSTALL_DIR ${CMAKE_BINARY_DIR}
 
71
  add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
72
  endif()
73
 
74
+ # Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
75
+ # If it's not, there will be an error to stderr.
76
+ # If it's supported, set a define to indicate that we should compile those shaders
77
+ execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
78
+ OUTPUT_VARIABLE glslc_output
79
+ ERROR_VARIABLE glslc_error)
80
+
81
+ if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
82
+ message(STATUS "GL_EXT_bfloat16 not supported by glslc")
83
+ set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
84
+ else()
85
+ message(STATUS "GL_EXT_bfloat16 supported by glslc")
86
+ set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
87
+ add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
88
+ endif()
89
+
90
  target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
91
  target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
92
 
 
158
  -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
159
  -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
160
  -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
161
+ -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
162
  BUILD_COMMAND ${CMAKE_COMMAND} --build .
163
  INSTALL_COMMAND ${CMAKE_COMMAND} --install .
164
  INSTALL_DIR ${CMAKE_BINARY_DIR}
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -51,6 +51,24 @@
51
 
52
  #include "ggml-vulkan-shaders.hpp"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
55
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
56
  static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -266,8 +284,9 @@ struct vk_device_struct {
266
  bool subgroup_require_full_support;
267
 
268
  bool coopmat_support;
269
- bool coopmat_acc_f32_support;
270
- bool coopmat_acc_f16_support;
 
271
  uint32_t coopmat_m;
272
  uint32_t coopmat_n;
273
  uint32_t coopmat_k;
@@ -293,6 +312,7 @@ struct vk_device_struct {
293
 
294
  vk_matmul_pipeline pipeline_matmul_f32 {};
295
  vk_matmul_pipeline pipeline_matmul_f32_f16 {};
 
296
  vk_matmul_pipeline2 pipeline_matmul_f16;
297
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
298
 
@@ -301,6 +321,7 @@ struct vk_device_struct {
301
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
302
 
303
  vk_matmul_pipeline pipeline_matmul_id_f32 {};
 
304
  vk_matmul_pipeline2 pipeline_matmul_id_f16;
305
  vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
306
 
@@ -333,8 +354,8 @@ struct vk_device_struct {
333
  vk_pipeline pipeline_clamp_f32;
334
  vk_pipeline pipeline_pad_f32;
335
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
336
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
337
- vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
338
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
339
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
340
  vk_pipeline pipeline_norm_f32;
@@ -1811,6 +1832,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
1811
  if (!device->pipeline_matmul_id_f32) {
1812
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1813
  }
 
 
 
 
 
 
1814
 
1815
  std::vector<std::future<void>> compiles;
1816
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
@@ -1920,6 +1947,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1920
  CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1921
 
1922
  CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
 
 
 
 
 
1923
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1924
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1925
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1941,6 +1973,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1941
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1942
 
1943
  CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 
 
 
 
 
1944
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1945
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1946
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1994,6 +2031,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1994
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1995
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1996
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
 
 
 
 
1997
 
1998
  if (device->coopmat_acc_f16_support) {
1999
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2042,6 +2084,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2042
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2043
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2044
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
 
 
 
 
2045
 
2046
  if (device->coopmat_acc_f16_support) {
2047
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2124,6 +2171,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2124
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2125
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2126
 
 
 
2127
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2128
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2129
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2159,6 +2208,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2159
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2160
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2161
 
 
 
2162
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2163
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2164
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2211,6 +2262,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2211
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2212
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2213
 
 
 
2214
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2215
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2216
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2246,6 +2299,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2246
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2247
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2248
 
 
 
2249
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2250
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2251
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2266,8 +2321,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
2266
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2267
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2268
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2269
- #undef CREATE_MM
2270
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2271
 
2272
  // mul mat vec
2273
 
@@ -2286,6 +2359,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2286
  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
2287
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2288
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
 
2289
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2290
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2291
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2308,6 +2382,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2308
 
2309
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2310
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
 
2311
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2312
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2313
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2331,6 +2406,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2331
 
2332
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2333
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
 
2334
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2335
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2336
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
@@ -2376,6 +2452,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2376
  // get_rows
2377
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2378
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 
2379
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2380
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2381
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2393,6 +2470,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2393
 
2394
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2395
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 
2396
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2397
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2398
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2430,10 +2508,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
2430
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2431
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2432
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
2433
 
2434
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2435
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2436
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
 
2437
  if (device->float_controls_rte_fp16) {
2438
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2439
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2601,6 +2682,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2601
  bool coopmat2_support = false;
2602
  device->coopmat_support = false;
2603
  device->integer_dot_product = false;
 
2604
 
2605
  for (const auto& properties : ext_props) {
2606
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2631,6 +2713,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2631
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
2632
  device->integer_dot_product = true;
2633
  #endif
 
 
 
2634
  }
2635
  }
2636
 
@@ -2817,6 +2902,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
2817
  }
2818
  #endif
2819
 
 
 
 
 
 
 
 
 
 
 
 
2820
  VkPhysicalDeviceMaintenance4Features maint4_features {};
2821
  maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2822
  if (maintenance4_support) {
@@ -3014,6 +3110,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
3014
  device->coopmat_int_n = prop.NSize;
3015
  device->coopmat_int_k = prop.KSize;
3016
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3017
  }
3018
 
3019
  if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
@@ -3021,11 +3136,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
3021
  GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
3022
  device->coopmat_support = false;
3023
  }
 
 
 
3024
  }
3025
 
3026
  if (device->coopmat_support) {
3027
  device_extensions.push_back("VK_KHR_cooperative_matrix");
3028
  }
 
 
 
 
 
3029
  #endif
3030
  device->name = GGML_VK_NAME + std::to_string(idx);
3031
 
@@ -3482,6 +3605,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
3482
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
3483
  return ctx->device->pipeline_matmul_f32_f16;
3484
  }
 
 
 
3485
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3486
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3487
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
@@ -3553,6 +3679,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
3553
  switch (a_type) {
3554
  case GGML_TYPE_F32:
3555
  case GGML_TYPE_F16:
 
3556
  case GGML_TYPE_Q4_0:
3557
  case GGML_TYPE_Q4_1:
3558
  case GGML_TYPE_Q5_0:
@@ -3585,6 +3712,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
3585
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
3586
  return ctx->device->pipeline_matmul_id_f32;
3587
  }
 
 
 
3588
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3589
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3590
  return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
@@ -3638,6 +3768,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
3638
  switch (a_type) {
3639
  case GGML_TYPE_F32:
3640
  case GGML_TYPE_F16:
 
3641
  case GGML_TYPE_Q4_0:
3642
  case GGML_TYPE_Q4_1:
3643
  case GGML_TYPE_Q5_0:
@@ -4373,6 +4504,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4373
  return ctx->device->pipeline_cpy_f16_f16;
4374
  }
4375
  }
 
 
 
 
 
 
 
4376
  if (src->type == GGML_TYPE_F32) {
4377
  switch (to) {
4378
  case GGML_TYPE_Q4_0:
@@ -4500,8 +4638,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4500
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
4501
  !ggml_vk_dim01_contiguous(src0);
4502
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
 
4503
  !ggml_vk_dim01_contiguous(src1);
4504
 
 
 
 
4505
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4506
 
4507
  bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
@@ -4511,25 +4653,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4511
 
4512
  if (mmp == nullptr) {
4513
  // Fall back to f16 dequant mul mat
4514
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4515
  quantize_y = false;
4516
  }
4517
 
4518
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4519
- const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
4520
 
4521
  if (qx_needs_dequant) {
4522
  // Fall back to dequant + f16 mulmat
4523
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
4524
  }
4525
 
4526
  // Not implemented
4527
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4528
 
4529
- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4530
  const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4531
 
4532
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4533
 
4534
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4535
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
@@ -4550,12 +4692,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4550
  vk_pipeline to_q8_1 = nullptr;
4551
 
4552
  if (x_non_contig) {
4553
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
4554
  } else {
4555
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
4556
  }
4557
  if (y_non_contig) {
4558
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
4559
  } else {
4560
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
4561
  }
@@ -5055,7 +5197,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
5055
  // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
5056
  // when ne12 and ne13 are one.
5057
  } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
5058
- (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5059
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
5060
  } else {
5061
  ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -5123,27 +5265,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5123
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
5124
  !ggml_vk_dim01_contiguous(src0);
5125
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
 
5126
  !ggml_vk_dim01_contiguous(src1);
5127
 
 
 
 
5128
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
5129
 
5130
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
5131
 
5132
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
5133
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
5134
 
5135
  if (qx_needs_dequant) {
5136
  // Fall back to dequant + f16 mulmat
5137
- mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
5138
  }
5139
 
5140
  // Not implemented
5141
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
5142
 
5143
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
5144
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
5145
 
5146
- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
5147
 
5148
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
5149
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
@@ -5162,12 +5308,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5162
  vk_pipeline to_fp16_vk_1 = nullptr;
5163
 
5164
  if (x_non_contig) {
5165
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
5166
  } else {
5167
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
5168
  }
5169
  if (y_non_contig) {
5170
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
5171
  } else {
5172
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
5173
  }
@@ -9295,6 +9441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9295
  switch (src0_type) {
9296
  case GGML_TYPE_F32:
9297
  case GGML_TYPE_F16:
 
9298
  case GGML_TYPE_Q4_0:
9299
  case GGML_TYPE_Q4_1:
9300
  case GGML_TYPE_Q5_0:
@@ -9330,10 +9477,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9330
  if (a->ne[3] != b->ne[3]) {
9331
  return false;
9332
  }
9333
- if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
9334
  !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
9335
  return false;
9336
  }
 
 
 
 
 
9337
 
9338
  return true;
9339
  } break;
@@ -9406,6 +9558,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9406
  switch (op->src[0]->type) {
9407
  case GGML_TYPE_F32:
9408
  case GGML_TYPE_F16:
 
9409
  case GGML_TYPE_Q4_0:
9410
  case GGML_TYPE_Q4_1:
9411
  case GGML_TYPE_Q5_0:
@@ -9436,6 +9589,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9436
  switch (src1_type) {
9437
  case GGML_TYPE_F32:
9438
  case GGML_TYPE_F16:
 
9439
  case GGML_TYPE_Q4_0:
9440
  case GGML_TYPE_Q4_1:
9441
  case GGML_TYPE_Q5_0:
 
51
 
52
  #include "ggml-vulkan-shaders.hpp"
53
 
54
+ // remove this once it's more widely available in the SDK
55
+ #if !defined(VK_KHR_shader_bfloat16)
56
+
57
+ #define VK_KHR_shader_bfloat16 1
58
+ #define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
59
+ #define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
60
+ #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
61
+ #define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
62
+
63
+ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
64
+ VkStructureType sType;
65
+ void* pNext;
66
+ VkBool32 shaderBFloat16Type;
67
+ VkBool32 shaderBFloat16DotProduct;
68
+ VkBool32 shaderBFloat16CooperativeMatrix;
69
+ } VkPhysicalDeviceShaderBfloat16FeaturesKHR;
70
+ #endif
71
+
72
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
73
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
74
  static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
284
  bool subgroup_require_full_support;
285
 
286
  bool coopmat_support;
287
+ bool coopmat_acc_f32_support {};
288
+ bool coopmat_acc_f16_support {};
289
+ bool coopmat_bf16_support {};
290
  uint32_t coopmat_m;
291
  uint32_t coopmat_n;
292
  uint32_t coopmat_k;
 
312
 
313
  vk_matmul_pipeline pipeline_matmul_f32 {};
314
  vk_matmul_pipeline pipeline_matmul_f32_f16 {};
315
+ vk_matmul_pipeline pipeline_matmul_bf16 {};
316
  vk_matmul_pipeline2 pipeline_matmul_f16;
317
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
318
 
 
321
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
322
 
323
  vk_matmul_pipeline pipeline_matmul_id_f32 {};
324
+ vk_matmul_pipeline pipeline_matmul_id_bf16 {};
325
  vk_matmul_pipeline2 pipeline_matmul_id_f16;
326
  vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
327
 
 
354
  vk_pipeline pipeline_clamp_f32;
355
  vk_pipeline pipeline_pad_f32;
356
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
357
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
358
+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
359
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
360
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
361
  vk_pipeline pipeline_norm_f32;
 
1832
  if (!device->pipeline_matmul_id_f32) {
1833
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1834
  }
1835
+ if (!device->pipeline_matmul_bf16) {
1836
+ device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
1837
+ }
1838
+ if (!device->pipeline_matmul_id_bf16) {
1839
+ device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
1840
+ }
1841
 
1842
  std::vector<std::future<void>> compiles;
1843
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
 
1947
  CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1948
 
1949
  CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1950
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
1951
+ if (device->coopmat_bf16_support) {
1952
+ CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1953
+ }
1954
+ #endif
1955
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1956
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1957
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
1973
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1974
 
1975
  CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1976
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
1977
+ if (device->coopmat_bf16_support) {
1978
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1979
+ }
1980
+ #endif
1981
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1982
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1983
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 
2031
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2032
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2033
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2034
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2035
+ if (device->coopmat_bf16_support) {
2036
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
2037
+ }
2038
+ #endif
2039
 
2040
  if (device->coopmat_acc_f16_support) {
2041
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2084
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2085
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2086
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2087
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2088
+ if (device->coopmat_bf16_support) {
2089
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2090
+ }
2091
+ #endif
2092
 
2093
  if (device->coopmat_acc_f16_support) {
2094
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2171
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2172
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2173
 
2174
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2175
+
2176
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2177
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2178
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2208
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2209
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2210
 
2211
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2212
+
2213
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2214
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2215
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2262
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2263
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2264
 
2265
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2266
+
2267
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2268
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2269
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2299
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2300
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2301
 
2302
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2303
+
2304
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2305
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2306
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2321
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2322
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2323
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2324
  }
2325
+ // reusing CREATE_MM from the fp32 path
2326
+ if ((device->coopmat2 || device->coopmat_support)
2327
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2328
+ && !device->coopmat_bf16_support
2329
+ #endif
2330
+ ) {
2331
+ // use scalar tile sizes
2332
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2333
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2334
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2335
+
2336
+ l_wg_denoms = {128, 128, 1 };
2337
+ m_wg_denoms = { 64, 64, 1 };
2338
+ s_wg_denoms = { 32, 32, 1 };
2339
+
2340
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2341
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2342
+ }
2343
+ #undef CREATE_MM
2344
 
2345
  // mul mat vec
2346
 
 
2359
  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
2360
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2361
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2362
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2363
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2364
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2365
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
 
2382
 
2383
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2384
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2385
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2386
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2387
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2388
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
 
2406
 
2407
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2408
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2409
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2410
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2411
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2412
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
 
2452
  // get_rows
2453
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2454
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2455
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2456
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2457
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2458
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
2470
 
2471
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2472
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2473
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2474
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2475
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2476
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
2508
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2509
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2510
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2511
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2512
 
2513
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2514
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2515
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2516
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2517
+
2518
  if (device->float_controls_rte_fp16) {
2519
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2520
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
 
2682
  bool coopmat2_support = false;
2683
  device->coopmat_support = false;
2684
  device->integer_dot_product = false;
2685
+ bool bfloat16_support = false;
2686
 
2687
  for (const auto& properties : ext_props) {
2688
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
 
2713
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
2714
  device->integer_dot_product = true;
2715
  #endif
2716
+ } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
2717
+ !getenv("GGML_VK_DISABLE_BFLOAT16")) {
2718
+ bfloat16_support = true;
2719
  }
2720
  }
2721
 
 
2902
  }
2903
  #endif
2904
 
2905
+ #if defined(VK_KHR_shader_bfloat16)
2906
+ VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
2907
+ bfloat16_features.pNext = nullptr;
2908
+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
2909
+ if (bfloat16_support) {
2910
+ last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
2911
+ last_struct = (VkBaseOutStructure *)&bfloat16_features;
2912
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
2913
+ }
2914
+ #endif
2915
+
2916
  VkPhysicalDeviceMaintenance4Features maint4_features {};
2917
  maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2918
  if (maintenance4_support) {
 
3110
  device->coopmat_int_n = prop.NSize;
3111
  device->coopmat_int_k = prop.KSize;
3112
  }
3113
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3114
+ if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
3115
+ prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
3116
+ prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
3117
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
3118
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
3119
+ ) {
3120
+ // coopmat sizes not set yet
3121
+ if (device->coopmat_m == 0) {
3122
+ device->coopmat_bf16_support = true;
3123
+ device->coopmat_m = prop.MSize;
3124
+ device->coopmat_n = prop.NSize;
3125
+ device->coopmat_k = prop.KSize;
3126
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
3127
+ // Only enable if shape is identical
3128
+ device->coopmat_bf16_support = true;
3129
+ }
3130
+ }
3131
+ #endif
3132
  }
3133
 
3134
  if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
 
3136
  GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
3137
  device->coopmat_support = false;
3138
  }
3139
+ if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
3140
+ device->coopmat_bf16_support = false;
3141
+ }
3142
  }
3143
 
3144
  if (device->coopmat_support) {
3145
  device_extensions.push_back("VK_KHR_cooperative_matrix");
3146
  }
3147
+ #if defined(VK_KHR_shader_bfloat16)
3148
+ if (device->coopmat_bf16_support) {
3149
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
3150
+ }
3151
+ #endif
3152
  #endif
3153
  device->name = GGML_VK_NAME + std::to_string(idx);
3154
 
 
3605
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
3606
  return ctx->device->pipeline_matmul_f32_f16;
3607
  }
3608
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
3609
+ return ctx->device->pipeline_matmul_bf16;
3610
+ }
3611
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3612
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3613
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
 
3679
  switch (a_type) {
3680
  case GGML_TYPE_F32:
3681
  case GGML_TYPE_F16:
3682
+ case GGML_TYPE_BF16:
3683
  case GGML_TYPE_Q4_0:
3684
  case GGML_TYPE_Q4_1:
3685
  case GGML_TYPE_Q5_0:
 
3712
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
3713
  return ctx->device->pipeline_matmul_id_f32;
3714
  }
3715
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
3716
+ return ctx->device->pipeline_matmul_id_bf16;
3717
+ }
3718
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3719
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3720
  return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
 
3768
  switch (a_type) {
3769
  case GGML_TYPE_F32:
3770
  case GGML_TYPE_F16:
3771
+ case GGML_TYPE_BF16:
3772
  case GGML_TYPE_Q4_0:
3773
  case GGML_TYPE_Q4_1:
3774
  case GGML_TYPE_Q5_0:
 
4504
  return ctx->device->pipeline_cpy_f16_f16;
4505
  }
4506
  }
4507
+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
4508
+ if (contig) {
4509
+ return ctx->device->pipeline_contig_cpy_f32_bf16;
4510
+ } else {
4511
+ return ctx->device->pipeline_cpy_f32_bf16;
4512
+ }
4513
+ }
4514
  if (src->type == GGML_TYPE_F32) {
4515
  switch (to) {
4516
  case GGML_TYPE_Q4_0:
 
4638
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
4639
  !ggml_vk_dim01_contiguous(src0);
4640
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
4641
+ (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
4642
  !ggml_vk_dim01_contiguous(src1);
4643
 
4644
+ // If src0 is BF16, try to use a BF16 x BF16 multiply
4645
+ ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
4646
+
4647
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4648
 
4649
  bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
4653
 
4654
  if (mmp == nullptr) {
4655
  // Fall back to f16 dequant mul mat
4656
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
4657
  quantize_y = false;
4658
  }
4659
 
4660
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4661
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
4662
 
4663
  if (qx_needs_dequant) {
4664
  // Fall back to dequant + f16 mulmat
4665
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
4666
  }
4667
 
4668
  // Not implemented
4669
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4670
 
4671
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4672
  const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4673
 
4674
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4675
 
4676
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4677
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
 
4692
  vk_pipeline to_q8_1 = nullptr;
4693
 
4694
  if (x_non_contig) {
4695
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
4696
  } else {
4697
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
4698
  }
4699
  if (y_non_contig) {
4700
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
4701
  } else {
4702
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
4703
  }
 
5197
  // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
5198
  // when ne12 and ne13 are one.
5199
  } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
5200
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
5201
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
5202
  } else {
5203
  ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
 
5265
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
5266
  !ggml_vk_dim01_contiguous(src0);
5267
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
5268
+ (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
5269
  !ggml_vk_dim01_contiguous(src1);
5270
 
5271
+ // If src0 is BF16, try to use a BF16 x BF16 multiply
5272
+ ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
5273
+
5274
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
5275
 
5276
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
5277
 
5278
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
5279
+ const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
5280
 
5281
  if (qx_needs_dequant) {
5282
  // Fall back to dequant + f16 mulmat
5283
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
5284
  }
5285
 
5286
  // Not implemented
5287
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
5288
 
5289
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
5290
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
5291
 
5292
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
5293
 
5294
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
5295
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
 
5308
  vk_pipeline to_fp16_vk_1 = nullptr;
5309
 
5310
  if (x_non_contig) {
5311
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
5312
  } else {
5313
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
5314
  }
5315
  if (y_non_contig) {
5316
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
5317
  } else {
5318
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
5319
  }
 
9441
  switch (src0_type) {
9442
  case GGML_TYPE_F32:
9443
  case GGML_TYPE_F16:
9444
+ case GGML_TYPE_BF16:
9445
  case GGML_TYPE_Q4_0:
9446
  case GGML_TYPE_Q4_1:
9447
  case GGML_TYPE_Q5_0:
 
9477
  if (a->ne[3] != b->ne[3]) {
9478
  return false;
9479
  }
9480
+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
9481
  !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
9482
  return false;
9483
  }
9484
+ if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
9485
+ // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
9486
+ // So don't support this combination for now.
9487
+ return false;
9488
+ }
9489
 
9490
  return true;
9491
  } break;
 
9558
  switch (op->src[0]->type) {
9559
  case GGML_TYPE_F32:
9560
  case GGML_TYPE_F16:
9561
+ case GGML_TYPE_BF16:
9562
  case GGML_TYPE_Q4_0:
9563
  case GGML_TYPE_Q4_1:
9564
  case GGML_TYPE_Q5_0:
 
9589
  switch (src1_type) {
9590
  case GGML_TYPE_F32:
9591
  case GGML_TYPE_F16:
9592
+ case GGML_TYPE_BF16:
9593
  case GGML_TYPE_Q4_0:
9594
  case GGML_TYPE_Q4_1:
9595
  case GGML_TYPE_Q5_0:
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt CHANGED
@@ -12,6 +12,9 @@ endif()
12
  if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
13
  add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
14
  endif()
 
 
 
15
  set(TARGET vulkan-shaders-gen)
16
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
17
  install(TARGETS ${TARGET} RUNTIME)
 
12
  if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
13
  add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
14
  endif()
15
+ if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
16
+ add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
17
+ endif()
18
  set(TARGET vulkan-shaders-gen)
19
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
20
  install(TARGETS ${TARGET} RUNTIME)
ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp CHANGED
@@ -18,7 +18,11 @@ void main() {
18
  // fast path for when all four iterations are in-bounds
19
  if (idx + (num_iter-1)*num_threads < p.ne) {
20
  [[unroll]] for (uint i = 0; i < num_iter; ++i) {
21
- #ifndef OPTIMIZATION_ERROR_WORKAROUND
 
 
 
 
22
  data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
23
  #else
24
  data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
31
  continue;
32
  }
33
 
34
- #ifndef OPTIMIZATION_ERROR_WORKAROUND
 
 
 
35
  data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
36
  #else
37
  data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
 
18
  // fast path for when all four iterations are in-bounds
19
  if (idx + (num_iter-1)*num_threads < p.ne) {
20
  [[unroll]] for (uint i = 0; i < num_iter; ++i) {
21
+
22
+ #if defined(DATA_D_BF16)
23
+ float f = float(data_a[get_aoffset() + idx]);
24
+ data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
25
+ #elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
26
  data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
27
  #else
28
  data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
 
35
  continue;
36
  }
37
 
38
+ #if defined(DATA_D_BF16)
39
+ float f = float(data_a[get_aoffset() + idx]);
40
+ data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
41
+ #elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
42
  data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
43
  #else
44
  data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
ggml/src/ggml-vulkan/vulkan-shaders/copy.comp CHANGED
@@ -12,7 +12,10 @@ void main() {
12
  return;
13
  }
14
 
15
- #ifndef OPTIMIZATION_ERROR_WORKAROUND
 
 
 
16
  data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
17
  #else
18
  data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
 
12
  return;
13
  }
14
 
15
+ #if defined(DATA_D_BF16)
16
+ float f = float(data_a[get_aoffset() + src0_idx(idx)]);
17
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
18
+ #elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
19
  data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
20
  #else
21
  data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp CHANGED
@@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
23
  }
24
  #endif
25
 
 
 
 
 
 
 
26
  #if defined(DATA_A_Q4_0)
27
  vec2 dequantize(uint ib, uint iqs, uint a_offset) {
28
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
428
  }
429
  #endif
430
 
431
- #if defined(DATA_A_F32) || defined(DATA_A_F16)
432
  vec2 get_dm(uint ib, uint a_offset) {
433
  return vec2(0, 0);
434
  }
 
23
  }
24
  #endif
25
 
26
+ #if defined(DATA_A_BF16)
27
+ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
28
+ return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
29
+ }
30
+ #endif
31
+
32
  #if defined(DATA_A_Q4_0)
33
  vec2 dequantize(uint ib, uint iqs, uint a_offset) {
34
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
 
434
  }
435
  #endif
436
 
437
+ #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
438
  vec2 get_dm(uint ib, uint a_offset) {
439
  return vec2(0, 0);
440
  }
ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp CHANGED
@@ -20,9 +20,14 @@ void main() {
20
  const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
21
  const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
22
 
 
 
 
 
 
23
  #ifndef OPTIMIZATION_ERROR_WORKAROUND
24
- data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
25
  #else
26
- data_d[d_offset + i00] = data_a[a_offset + i00];
27
  #endif
28
  }
 
20
  const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
21
  const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
22
 
23
+ #if defined(DATA_A_BF16)
24
+ FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
25
+ #else
26
+ FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
27
+ #endif
28
  #ifndef OPTIMIZATION_ERROR_WORKAROUND
29
+ data_d[d_offset + i00] = D_TYPE(v);
30
  #else
31
+ data_d[d_offset + i00] = D_TYPE(v);
32
  #endif
33
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp CHANGED
@@ -6,7 +6,7 @@
6
 
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
 
9
- #if !defined(DATA_A_F32) && !defined(DATA_A_F16)
10
  #define K_PER_ITER 8
11
  #else
12
  #define K_PER_ITER 2
 
6
 
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
 
9
+ #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
10
  #define K_PER_ITER 8
11
  #else
12
  #define K_PER_ITER 2
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp CHANGED
@@ -10,6 +10,10 @@
10
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
11
  #endif
12
 
 
 
 
 
13
  #ifdef COOPMAT
14
  #extension GL_KHR_cooperative_matrix : enable
15
  #extension GL_KHR_memory_scope_semantics : enable
@@ -29,6 +33,10 @@
29
  #define LOAD_VEC_B 1
30
  #endif
31
 
 
 
 
 
32
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
33
 
34
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -202,8 +210,8 @@ void main() {
202
  #endif
203
 
204
  #ifdef COOPMAT
205
- coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
206
- coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
207
  coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
208
 
209
  [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +256,21 @@ void main() {
248
  buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
249
  }
250
  #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  #elif defined(DATA_A_Q4_0)
252
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
253
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
@@ -695,13 +718,13 @@ void main() {
695
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
696
  #endif
697
  const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
698
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
699
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
700
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
701
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
702
  #elif !MUL_MAT_ID
703
  if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
704
- buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
705
  } else {
706
  buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
707
  }
@@ -709,7 +732,7 @@ void main() {
709
  const uint row_i = ic * BN + loadc_b + l;
710
  if (row_i < _ne1) {
711
  const u16vec2 row_idx = row_ids[row_i];
712
- buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
713
  } else {
714
  buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
715
  }
 
10
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
11
  #endif
12
 
13
+ #if defined(DATA_A_BF16) && defined(COOPMAT)
14
+ #extension GL_EXT_bfloat16 : enable
15
+ #endif
16
+
17
  #ifdef COOPMAT
18
  #extension GL_KHR_cooperative_matrix : enable
19
  #extension GL_KHR_memory_scope_semantics : enable
 
33
  #define LOAD_VEC_B 1
34
  #endif
35
 
36
+ #if !defined(TO_FLOAT_TYPE)
37
+ #define TO_FLOAT_TYPE FLOAT_TYPE
38
+ #endif
39
+
40
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
41
 
42
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 
210
  #endif
211
 
212
  #ifdef COOPMAT
213
+ coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
214
+ coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
215
  coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
216
 
217
  [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
 
256
  buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
257
  }
258
  #endif
259
+ #elif defined(DATA_A_BF16)
260
+ #if LOAD_VEC_A == 4
261
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
262
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
263
+ buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
264
+ buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
265
+ buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
266
+ buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
267
+ #else
268
+ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
269
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
270
+ } else {
271
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
272
+ }
273
+ #endif
274
  #elif defined(DATA_A_Q4_0)
275
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
276
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
 
718
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
719
  #endif
720
  const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
721
+ buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
722
+ buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
723
+ buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
724
+ buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
725
  #elif !MUL_MAT_ID
726
  if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
727
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
728
  } else {
729
  buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
730
  }
 
732
  const uint row_i = ic * BN + loadc_b + l;
733
  if (row_i < _ne1) {
734
  const u16vec2 row_idx = row_ids[row_i];
735
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
736
  } else {
737
  buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
738
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp CHANGED
@@ -14,6 +14,9 @@
14
  #extension GL_EXT_buffer_reference : enable
15
  #extension GL_KHR_shader_subgroup_ballot : enable
16
  #extension GL_KHR_shader_subgroup_vote : enable
 
 
 
17
 
18
  #include "types.comp"
19
 
@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
80
  #define store_scales(a)
81
  #endif
82
 
 
 
 
 
 
 
83
  #ifdef MUL_MAT_ID
84
  layout (binding = 3) readonly buffer IDS {int data_ids[];};
85
 
@@ -271,8 +280,8 @@ void main() {
271
 
272
  // Manually partial unroll
273
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
274
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
276
 
277
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
286
  store_scales(tid);
287
  }
288
  while (block_k < end_k) {
289
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
290
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
291
 
292
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
293
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {
310
 
311
  // Manually partial unroll
312
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
313
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
315
 
316
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
317
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
325
  store_scales(tid);
326
  }
327
  while (block_k < end_k) {
328
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
329
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
330
 
331
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
332
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {
350
 
351
  // Manually partial unroll
352
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
353
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
354
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
355
 
356
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
357
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
365
  store_scales(tid);
366
  }
367
  while (block_k < end_k) {
368
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
369
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
370
 
371
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
372
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
405
  fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
406
  }
407
 
408
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
409
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
410
 
411
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
412
  #ifdef MUL_MAT_ID
 
14
  #extension GL_EXT_buffer_reference : enable
15
  #extension GL_KHR_shader_subgroup_ballot : enable
16
  #extension GL_KHR_shader_subgroup_vote : enable
17
+ #ifdef DATA_A_BF16
18
+ #extension GL_EXT_bfloat16 : enable
19
+ #endif
20
 
21
  #include "types.comp"
22
 
 
83
  #define store_scales(a)
84
  #endif
85
 
86
+ #if defined(DATA_A_BF16)
87
+ #define MAT_TYPE bfloat16_t
88
+ #else
89
+ #define MAT_TYPE FLOAT_TYPE
90
+ #endif
91
+
92
  #ifdef MUL_MAT_ID
93
  layout (binding = 3) readonly buffer IDS {int data_ids[];};
94
 
 
280
 
281
  // Manually partial unroll
282
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
283
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
285
 
286
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
287
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
 
295
  store_scales(tid);
296
  }
297
  while (block_k < end_k) {
298
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
299
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
300
 
301
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
302
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
 
319
 
320
  // Manually partial unroll
321
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
322
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
323
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
324
 
325
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
326
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
 
334
  store_scales(tid);
335
  }
336
  while (block_k < end_k) {
337
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
338
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
339
 
340
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
341
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
 
359
 
360
  // Manually partial unroll
361
  [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
362
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
363
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
364
 
365
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
366
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
374
  store_scales(tid);
375
  }
376
  while (block_k < end_k) {
377
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
378
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
379
 
380
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
381
  coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
414
  fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415
  }
416
 
417
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419
 
420
  coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
421
  #ifdef MUL_MAT_ID
ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #version 460
2
+
3
+ #extension GL_EXT_bfloat16 : require
4
+
5
+ void main()
6
+ {
7
+ }
ggml/src/ggml-vulkan/vulkan-shaders/types.comp CHANGED
@@ -33,6 +33,19 @@
33
  #endif
34
  #endif
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  #define QUANT_K_Q4_0 32
37
  #define QUANT_R_Q4_0 2
38
 
@@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
1343
  }
1344
  #endif
1345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1346
  #endif // !defined(GGML_TYPES_COMP)
 
33
  #endif
34
  #endif
35
 
36
+ #if defined(DATA_A_BF16)
37
+ #define QUANT_K 1
38
+ #define QUANT_R 1
39
+
40
+ #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
41
+ #define A_TYPE uint16_t
42
+ #elif LOAD_VEC_A == 4
43
+ #define A_TYPE u16vec4
44
+ #elif LOAD_VEC_A == 8
45
+ #error unsupported
46
+ #endif
47
+ #endif
48
+
49
  #define QUANT_K_Q4_0 32
50
  #define QUANT_R_Q4_0 2
51
 
 
1356
  }
1357
  #endif
1358
 
1359
+ // returns the bfloat value in the low 16b.
1360
+ // See ggml_compute_fp32_to_bf16
1361
+ uint32_t fp32_to_bf16(float f)
1362
+ {
1363
+ uint32_t u = floatBitsToUint(f);
1364
+ u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
1365
+ return u;
1366
+ }
1367
+
1368
+ float bf16_to_fp32(uint32_t u)
1369
+ {
1370
+ return uintBitsToFloat(u << 16);
1371
+ }
1372
+
1373
  #endif // !defined(GGML_TYPES_COMP)
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
63
  "iq3_xxs",
64
  "iq3_s",
65
  "iq4_xs",
66
- "iq4_nl"
 
67
  };
68
 
69
  namespace {
@@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
296
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
297
 
298
  std::map<std::string, std::string> base_dict = {
299
- {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
300
  {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
301
  };
302
  std::string shader_name = "matmul";
@@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
318
 
319
  const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  // Shaders with f16 B_TYPE
322
- string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
323
- string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
- string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
326
- string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
 
 
 
 
 
 
 
327
 
328
  for (const auto& tname : type_names) {
329
  std::string load_vec_quant = "2";
@@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
332
  else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
333
  load_vec_quant = "4";
334
 
 
 
 
 
335
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
336
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
337
- std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
338
  // For aligned matmul loads
339
- std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
340
 
341
  // don't generate f32 variants for coopmat2
342
  if (!coopmat2) {
343
- string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
344
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
345
  }
346
 
347
  if (tname != "f16" && tname != "f32") {
348
- string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
349
- string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
350
  }
351
 
352
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
353
  if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
354
- string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
355
  }
356
  #endif
357
  }
@@ -393,6 +430,7 @@ void process_shaders() {
393
  if (tname == "f32") {
394
  continue;
395
  }
 
396
 
397
  if (tname == "f16") {
398
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
@@ -417,12 +455,12 @@ void process_shaders() {
417
  string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
418
 
419
  // Dequant shaders
420
- if (tname != "f16") {
421
  string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
422
  }
423
 
424
  if (!string_ends_with(tname, "_k")) {
425
- shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
426
 
427
  if (tname == "f16") {
428
  string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@@ -447,9 +485,11 @@ void process_shaders() {
447
  string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
448
  string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
449
  string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 
450
  string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
  string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
452
  string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 
453
 
454
  for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
455
  string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
63
  "iq3_xxs",
64
  "iq3_s",
65
  "iq4_xs",
66
+ "iq4_nl",
67
+ "bf16",
68
  };
69
 
70
  namespace {
 
297
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
298
 
299
  std::map<std::string, std::string> base_dict = {
 
300
  {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
301
  };
302
  std::string shader_name = "matmul";
 
318
 
319
  const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
320
 
321
+ auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
322
+ if (t == "bf16") {
323
+ // scalar path promotes to float
324
+ if (!coopmat && !coopmat2) {
325
+ return "float";
326
+ }
327
+ return "bfloat16_t";
328
+ }
329
+ if (coopmat2 || fp16) {
330
+ return "float16_t";
331
+ }
332
+ return "float";
333
+ };
334
+
335
  // Shaders with f16 B_TYPE
336
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
337
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
338
+
339
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
340
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
341
+
342
+ // bf16
343
+ {
344
+ std::string load_vec_a_unaligned = "1";
345
+ // For aligned matmul loads
346
+ std::string load_vec_a = coopmat2 ? "1" : "4";
347
+
348
+ // scalar path promotes to float
349
+ std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
350
 
351
+ // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
352
+ #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
353
+ if (!(coopmat || coopmat2))
354
+ #endif
355
+ {
356
+ string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
357
+ string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
358
+ }
359
+ }
360
 
361
  for (const auto& tname : type_names) {
362
  std::string load_vec_quant = "2";
 
365
  else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
366
  load_vec_quant = "4";
367
 
368
+ if (tname == "bf16") {
369
+ continue;
370
+ }
371
+
372
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
373
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
374
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
375
  // For aligned matmul loads
376
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
377
 
378
  // don't generate f32 variants for coopmat2
379
  if (!coopmat2) {
380
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
381
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
382
  }
383
 
384
  if (tname != "f16" && tname != "f32") {
385
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
386
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
387
  }
388
 
389
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
390
  if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
391
+ string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
392
  }
393
  #endif
394
  }
 
430
  if (tname == "f32") {
431
  continue;
432
  }
433
+ if (tname == "bf16") continue;
434
 
435
  if (tname == "f16") {
436
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
 
455
  string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
456
 
457
  // Dequant shaders
458
+ if (tname != "f16" && tname != "bf16") {
459
  string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
460
  }
461
 
462
  if (!string_ends_with(tname, "_k")) {
463
+ shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
464
 
465
  if (tname == "f16") {
466
  string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
 
485
  string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
486
  string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
487
  string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
488
+ string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
489
  string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
490
  string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
491
  string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
492
+ string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
493
 
494
  for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
495
  string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});