Commit 8c833e9 (1 parent: 8f2e8d6)
Committed by xctan and ggerganov

ggml-cpu : split arch-specific implementations (llama/13892)


* move ggml-cpu-aarch64 to repack

* split quantize_row_q8_0/1

* split helper functions

* split ggml_vec_dot_q4_0_q8_0

* split ggml_vec_dot_q4_1_q8_1

* split ggml_vec_dot_q5_0_q8_0

* split ggml_vec_dot_q5_1_q8_1

* split ggml_vec_dot_q8_0_q8_0

* split ggml_vec_dot_tq1_0_q8_K

* split ggml_vec_dot_tq2_0_q8_K

* split ggml_vec_dot_q2_K_q8_K

* split ggml_vec_dot_q3_K_q8_K

* split ggml_vec_dot_q4_K_q8_K

* split ggml_vec_dot_q5_K_q8_K

* split ggml_vec_dot_q6_K_q8_K

* split ggml_vec_dot_iq2_xxs_q8_K

* split ggml_vec_dot_iq2_xs_q8_K

* split ggml_vec_dot_iq2_s_q8_K

* split ggml_vec_dot_iq3_xxs_q8_K

* split ggml_vec_dot_iq3_s_q8_K

* split ggml_vec_dot_iq1_s_q8_K

* split ggml_vec_dot_iq1_m_q8_K

* split ggml_vec_dot_iq4_nl_q8_0

* split ggml_vec_dot_iq4_xs_q8_K

* fix typos

* fix missing prototypes

* rename ggml-cpu-quants.c

* rename ggml-cpu-traits

* rename arm folder

* move cpu-feats-x86.cpp

* rename ggml-cpu-hbm

* update arm detection macro in quants.c

* move iq quant tables

* split ggml_quantize_mat_q8_0/K

* split ggml_gemv_*

* split ggml_gemm_*

* rename namespace aarch64 to repack

* use weak aliases to replace test macros

* rename GGML_CPU_AARCH64 to GGML_CPU_REPACK

* rename more aarch64 to repack

* clean up rebase leftover

* fix compilation errors

* remove trailing spaces

* try to fix clang compilation errors

* try to fix clang compilation errors again

* try to fix clang compilation errors, 3rd attempt

* try to fix clang compilation errors, 4th attempt

* try to fix clang compilation errors, 5th attempt

* try to fix clang compilation errors, 6th attempt

* try to fix clang compilation errors, 7th attempt

* try to fix clang compilation errors, 8th attempt

* try to fix clang compilation errors, 9th attempt

* more cleanup

* fix compilation errors

* fix apple targets

* fix a typo in arm version of ggml_vec_dot_q4_K_q8_K

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
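
One of the bullets above replaces per-architecture test macros with weak aliases, so that a single generic definition can be overridden at link time by an arch-specific one. The snippet below is a minimal, self-contained sketch of that linker mechanism on ELF targets with GCC/Clang; the function names (vec_dot, vec_dot_generic) are illustrative placeholders, not the identifiers used in the ggml-cpu sources.

/* generic.c -- sketch of the weak-alias pattern (assumed names, not ggml's) */
#include <stdio.h>

/* Generic scalar implementation, always compiled. */
void vec_dot_generic(int n, float * s, const float * x, const float * y) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i] * y[i];
    }
    *s = sum;
}

/* Weak alias: vec_dot resolves to the generic body unless another
 * translation unit (e.g. an arch-specific file) provides a strong
 * definition of vec_dot, which then wins at link time. */
void vec_dot(int n, float * s, const float * x, const float * y)
    __attribute__((weak, alias("vec_dot_generic")));

int main(void) {
    const float x[4] = {1, 2, 3, 4};
    const float y[4] = {1, 1, 1, 1};
    float s = 0.0f;
    vec_dot(4, &s, x, y);
    printf("%f\n", s); /* 10.000000 when no strong override is linked in */
    return 0;
}

The alias attribute is not available on Mach-O in the same way as on ELF, which is consistent with the separate "fix apple targets" step in the list above.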

ggml/CMakeLists.txt CHANGED
@@ -105,7 +105,7 @@ message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
 message(DEBUG "INS_ENB : ${INS_ENB}")

 option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
-option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
+option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
 option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
 option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
ggml/src/ggml-common.h CHANGED
@@ -1074,6 +1074,10 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()

+GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
+    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
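
The kvalues_iq4nl table moved into ggml-common.h above is the 16-entry non-linear codebook used to expand packed 4-bit IQ4_NL indices. Below is a minimal, self-contained sketch of how such a table is consumed in the scalar paths (compare the fallback loop in ggml_gemv_iq4_nl_4x4_q8_0 further down in this diff); the packing layout is simplified for illustration and the helper name is hypothetical.

#include <stdint.h>
#include <stdio.h>

static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

/* Expand each packed byte into two signed codebook values:
 * low nibble first, then high nibble. */
static void dequant_iq4_nibbles(const uint8_t * qs, int n_bytes, int8_t * out) {
    for (int i = 0; i < n_bytes; i++) {
        out[2*i + 0] = kvalues_iq4nl[qs[i] & 0x0F];
        out[2*i + 1] = kvalues_iq4nl[qs[i] >> 4];
    }
}

int main(void) {
    const uint8_t qs[2] = { 0xF0, 0x81 }; /* nibble index pairs (0,15) and (1,8) */
    int8_t v[4];
    dequant_iq4_nibbles(qs, 2, v);
    printf("%d %d %d %d\n", v[0], v[1], v[2], v[3]); /* -127 113 -104 1 */
    return 0;
}

On ARM NEON the same lookup is done 16 lanes at a time with vqtbl1q_s8, as in the ggml_gemv_iq4_nl_4x4_q8_0 kernel later in this diff.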
ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -10,14 +10,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     list (APPEND GGML_CPU_SOURCES
         ggml-cpu/ggml-cpu.c
         ggml-cpu/ggml-cpu.cpp
-        ggml-cpu/ggml-cpu-aarch64.cpp
-        ggml-cpu/ggml-cpu-aarch64.h
-        ggml-cpu/ggml-cpu-hbm.cpp
-        ggml-cpu/ggml-cpu-hbm.h
-        ggml-cpu/ggml-cpu-quants.c
-        ggml-cpu/ggml-cpu-quants.h
-        ggml-cpu/ggml-cpu-traits.cpp
-        ggml-cpu/ggml-cpu-traits.h
+        ggml-cpu/repack.cpp
+        ggml-cpu/repack.h
+        ggml-cpu/hbm.cpp
+        ggml-cpu/hbm.h
+        ggml-cpu/quants.c
+        ggml-cpu/quants.h
+        ggml-cpu/traits.cpp
+        ggml-cpu/traits.h
         ggml-cpu/amx/amx.cpp
         ggml-cpu/amx/amx.h
         ggml-cpu/amx/mmq.cpp
@@ -84,6 +84,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

     if (GGML_SYSTEM_ARCH STREQUAL "ARM")
         message(STATUS "ARM detected")
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/arch/arm/quants.c
+            ggml-cpu/arch/arm/repack.cpp
+            )
+
         if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
             message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
         else()
@@ -167,6 +172,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
         message(STATUS "x86 detected")
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/arch/x86/quants.c
+            ggml-cpu/arch/x86/repack.cpp
+            )
+
         if (MSVC)
             # instruction set detection for MSVC only
             if (GGML_NATIVE)
@@ -302,7 +312,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             # Since multiple variants of the CPU backend may be included in the same
             # build, using set_source_files_properties() to set the arch flags is not possible
             set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats)
-            add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp)
+            add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/x86/cpu-feats.cpp)
             target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include)
             target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS})
             target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
@@ -311,6 +321,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
         message(STATUS "PowerPC detected")
+        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/powerpc/quants.c)
         if (GGML_NATIVE)
             if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
                 file(READ "/proc/cpuinfo" POWER10_M)
@@ -338,6 +349,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "loongarch64")
         message(STATUS "loongarch64 detected")
+        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/loongarch/quants.c)
+
         list(APPEND ARCH_FLAGS -march=loongarch64)
         if (GGML_LASX)
             list(APPEND ARCH_FLAGS -mlasx)
@@ -347,6 +360,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
         message(STATUS "riscv64 detected")
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/arch/riscv/quants.c
+            ggml-cpu/arch/riscv/repack.cpp
+            )
         if (GGML_RVV)
             if (GGML_XTHEADVECTOR)
                 list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
@@ -358,6 +375,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
         message(STATUS "s390x detected")
+        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
         file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
         string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})

@@ -381,12 +399,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         if (GGML_VXE)
             list(APPEND ARCH_FLAGS -mvx -mzvector)
         endif()
+    elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
+        message(STATUS "Wasm detected")
+        list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
     else()
-        message(STATUS "Unknown architecture")
+        message(WARNING "Unknown CPU architecture. Falling back to generic implementations.")
+        list(APPEND ARCH_FLAGS -DGGML_CPU_GENERIC)
     endif()

-    if (GGML_CPU_AARCH64)
-        target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
+    if (GGML_CPU_REPACK)
+        target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_REPACK)
     endif()

     if (GGML_CPU_KLEIDIAI)
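
The build option is renamed from GGML_CPU_AARCH64 to GGML_CPU_REPACK and now feeds the GGML_USE_CPU_REPACK compile definition; it can be toggled at configure time, e.g. cmake -B build -DGGML_CPU_REPACK=OFF. The snippet below is only a sketch of the usual way such a definition is consumed in source code; the helper function is hypothetical, and only the macro name comes from the CMake change above.

#include <stdbool.h>

/* Hypothetical helper: reports whether runtime weight repacking was compiled in. */
static bool cpu_repack_enabled(void) {
#ifdef GGML_USE_CPU_REPACK
    return true;   /* GGML_CPU_REPACK=ON (the default) */
#else
    return false;  /* built with -DGGML_CPU_REPACK=OFF */
#endif
}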
ggml/src/ggml-cpu/amx/amx.cpp CHANGED
@@ -5,7 +5,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-cpu.h"
-#include "ggml-cpu-traits.h"
+#include "traits.h"

 #if defined(__gnu_linux__)
 #include <sys/syscall.h>
ggml/src/ggml-cpu/amx/mmq.cpp CHANGED
@@ -8,7 +8,7 @@
 #include "mmq.h"
 #include "ggml-impl.h"
 #include "ggml-cpu-impl.h"
-#include "ggml-cpu-quants.h"
+#include "quants.h"
 #include "ggml-quants.h"
 #include <algorithm>
 #include <type_traits>
ggml/src/ggml-cpu/arch/arm/quants.c ADDED
The diff for this file is too large to render. See raw diff
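
Since the diff for arch/arm/quants.c is too large to render here, the following is only a structural sketch of the pattern the split per-arch files follow (per the commit list above and the repack.cpp shown below): a SIMD path guarded by architecture feature macros that returns early, falling through to a portable scalar loop otherwise. The function name and signature are simplified placeholders, not the real ggml prototypes.

void vec_dot_sketch(int n, float * s, const signed char * x, const signed char * y) {
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    /* vectorized path: load 16-byte chunks, accumulate with vdotq_s32,
     * scale, store the result, and end with an early return; */
#endif
    /* scalar fallback, always compiled so every architecture has a definition */
    int sum = 0;
    for (int i = 0; i < n; i++) {
        sum += x[i] * y[i];
    }
    *s = (float) sum;
}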
 
ggml/src/ggml-cpu/arch/arm/repack.cpp ADDED
@@ -0,0 +1,2174 @@
1
+ #define GGML_COMMON_IMPL_CPP
2
+ #define GGML_COMMON_DECL_CPP
3
+ #include "ggml-common.h"
4
+ #include "ggml-backend-impl.h"
5
+
6
+ #include "ggml-impl.h"
7
+ #include "ggml-cpu.h"
8
+ #include "ggml-cpu-impl.h"
9
+ #include "traits.h"
10
+
11
+ #include <cmath>
12
+ #include <cstring>
13
+ #include <cassert>
14
+ #include <cstdlib> // for qsort
15
+ #include <cstdio> // for GGML_ASSERT
16
+
17
+ #define GGML_CPU_CLANG_WORKAROUND
18
+ #include "../../repack.h"
19
+
20
+ #if defined(__GNUC__)
21
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
22
+ #endif
23
+
24
+ #define UNUSED GGML_UNUSED
25
+
26
+ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
27
+ assert(QK8_0 == 32);
28
+ assert(k % QK8_0 == 0);
29
+ const int nb = k / QK8_0;
30
+
31
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
32
+
33
+ #if defined(__ARM_NEON)
34
+ float32x4_t srcv[4][8];
35
+ float id[4];
36
+
37
+ for (int i = 0; i < nb; i++) {
38
+ float32x4_t asrcv[8];
39
+ float32x4_t amaxv[8];
40
+
41
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
42
+ for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
43
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
44
+
45
+ for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
46
+ for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
47
+ for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
48
+
49
+ const float amax = vmaxvq_f32(amaxv[0]);
50
+
51
+ const float d = amax / ((1 << 7) - 1);
52
+ id[row_iter] = d ? 1.0f / d : 0.0f;
53
+
54
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
55
+ }
56
+
57
+ for (int j = 0; j < 8; j++) {
58
+ float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
59
+ int32x4_t vi = vcvtnq_s32_f32(v);
60
+ y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
61
+ y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
62
+ y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
63
+ y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
64
+
65
+ v = vmulq_n_f32(srcv[1][j], id[1]);
66
+ vi = vcvtnq_s32_f32(v);
67
+ y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
68
+ y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
69
+ y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
70
+ y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
71
+
72
+ v = vmulq_n_f32(srcv[2][j], id[2]);
73
+ vi = vcvtnq_s32_f32(v);
74
+ y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
75
+ y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
76
+ y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
77
+ y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
78
+
79
+ v = vmulq_n_f32(srcv[3][j], id[3]);
80
+ vi = vcvtnq_s32_f32(v);
81
+ y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
82
+ y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
83
+ y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
84
+ y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
85
+ }
86
+ }
87
+ #else
88
+ // scalar
89
+ const int blck_size_interleave = 4;
90
+ float srcv[4][QK8_0];
91
+ float id[4];
92
+
93
+ for (int i = 0; i < nb; i++) {
94
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
95
+ float amax = 0.0f; // absolute max
96
+
97
+ for (int j = 0; j < QK8_0; j++) {
98
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
99
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
100
+ }
101
+
102
+ const float d = amax / ((1 << 7) - 1);
103
+ id[row_iter] = d ? 1.0f / d : 0.0f;
104
+
105
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
106
+ }
107
+
108
+ for (int j = 0; j < QK8_0 * 4; j++) {
109
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
110
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
111
+ src_offset += (j % blck_size_interleave);
112
+
113
+ float x0 = srcv[src_id][src_offset] * id[src_id];
114
+ y[i].qs[j] = roundf(x0);
115
+ }
116
+ }
117
+ #endif
118
+ }
119
+
120
+ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
121
+ assert(QK8_0 == 32);
122
+ assert(k % QK8_0 == 0);
123
+ const int nb = k / QK8_0;
124
+
125
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
126
+
127
+ #if defined(__ARM_NEON)
128
+ float32x4_t srcv[4][8];
129
+ float id[4];
130
+
131
+ for (int i = 0; i < nb; i++) {
132
+ float32x4_t asrcv[8];
133
+ float32x4_t amaxv[8];
134
+
135
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
136
+ for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
137
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
138
+
139
+ for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
140
+ for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
141
+ for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
142
+
143
+ const float amax = vmaxvq_f32(amaxv[0]);
144
+
145
+ const float d = amax / ((1 << 7) - 1);
146
+ id[row_iter] = d ? 1.0f / d : 0.0f;
147
+
148
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
149
+ }
150
+
151
+ for (int j = 0; j < 4; j++) {
152
+ float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
153
+ int32x4_t vi = vcvtnq_s32_f32(v);
154
+ y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
155
+ y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
156
+ y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
157
+ y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
158
+ v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
159
+ vi = vcvtnq_s32_f32(v);
160
+ y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
161
+ y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
162
+ y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
163
+ y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
164
+
165
+ v = vmulq_n_f32(srcv[1][2 * j], id[1]);
166
+ vi = vcvtnq_s32_f32(v);
167
+ y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
168
+ y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
169
+ y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
170
+ y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
171
+ v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
172
+ vi = vcvtnq_s32_f32(v);
173
+ y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
174
+ y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
175
+ y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
176
+ y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
177
+
178
+ v = vmulq_n_f32(srcv[2][2 * j], id[2]);
179
+ vi = vcvtnq_s32_f32(v);
180
+ y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
181
+ y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
182
+ y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
183
+ y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
184
+ v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
185
+ vi = vcvtnq_s32_f32(v);
186
+ y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
187
+ y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
188
+ y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
189
+ y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
190
+
191
+ v = vmulq_n_f32(srcv[3][2 * j], id[3]);
192
+ vi = vcvtnq_s32_f32(v);
193
+ y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
194
+ y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
195
+ y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
196
+ y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
197
+ v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
198
+ vi = vcvtnq_s32_f32(v);
199
+ y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
200
+ y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
201
+ y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
202
+ y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
203
+ }
204
+ }
205
+
206
+ #else
207
+ // scalar
208
+ const int blck_size_interleave = 8;
209
+ float srcv[4][QK8_0];
210
+ float id[4];
211
+
212
+ for (int i = 0; i < nb; i++) {
213
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
214
+ float amax = 0.0f; // absolute max
215
+
216
+ for (int j = 0; j < QK8_0; j++) {
217
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
218
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
219
+ }
220
+
221
+ const float d = amax / ((1 << 7) - 1);
222
+ id[row_iter] = d ? 1.0f / d : 0.0f;
223
+
224
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
225
+ }
226
+
227
+ for (int j = 0; j < QK8_0 * 4; j++) {
228
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
229
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
230
+ src_offset += (j % blck_size_interleave);
231
+
232
+ float x0 = srcv[src_id][src_offset] * id[src_id];
233
+ y[i].qs[j] = roundf(x0);
234
+ }
235
+ }
236
+ #endif
237
+ }
238
+
239
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
240
+ const int qk = QK8_0;
241
+ const int nb = n / qk;
242
+ const int ncols_interleaved = 4;
243
+ const int blocklen = 4;
244
+
245
+ assert (n % qk == 0);
246
+ assert (nc % ncols_interleaved == 0);
247
+
248
+ UNUSED(s);
249
+ UNUSED(bs);
250
+ UNUSED(vx);
251
+ UNUSED(vy);
252
+ UNUSED(nr);
253
+ UNUSED(nc);
254
+ UNUSED(nb);
255
+ UNUSED(ncols_interleaved);
256
+ UNUSED(blocklen);
257
+
258
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
259
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
260
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
261
+
262
+ for (int c = 0; c < nc; c += ncols_interleaved) {
263
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
264
+ float32x4_t acc = vdupq_n_f32(0);
265
+ for (int b = 0; b < nb; b++) {
266
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
267
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
268
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
269
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
270
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
271
+
272
+ int8x16_t a0 = vld1q_s8(a_ptr->qs);
273
+ int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
274
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
275
+
276
+ int32x4_t ret = vdupq_n_s32(0);
277
+
278
+ ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
279
+ ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
280
+ ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
281
+ ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
282
+
283
+ ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
284
+ ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
285
+ ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
286
+ ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
287
+
288
+ acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
289
+ vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
290
+ a_ptr++;
291
+ b_ptr++;
292
+ }
293
+ vst1q_f32(s, acc);
294
+ s += ncols_interleaved;
295
+ }
296
+ return;
297
+ }
298
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
299
+ float sumf[4];
300
+ int sumi;
301
+
302
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
303
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
304
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
305
+
306
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
307
+ for (int l = 0; l < nb; l++) {
308
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
309
+ for (int j = 0; j < ncols_interleaved; j++) {
310
+ sumi = 0;
311
+ for (int i = 0; i < blocklen; ++i) {
312
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
313
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
314
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
315
+ }
316
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
317
+ }
318
+ }
319
+ }
320
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
321
+ }
322
+ }
323
+
324
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
325
+ const int qk = QK8_0;
326
+ const int nb = n / qk;
327
+ const int ncols_interleaved = 4;
328
+ const int blocklen = 8;
329
+
330
+ assert (n % qk == 0);
331
+ assert (nc % ncols_interleaved == 0);
332
+
333
+ UNUSED(s);
334
+ UNUSED(bs);
335
+ UNUSED(vx);
336
+ UNUSED(vy);
337
+ UNUSED(nr);
338
+ UNUSED(nc);
339
+ UNUSED(nb);
340
+ UNUSED(ncols_interleaved);
341
+ UNUSED(blocklen);
342
+
343
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
344
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
345
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
346
+
347
+ for (int c = 0; c < nc; c += ncols_interleaved) {
348
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
349
+ float32x4_t acc = vdupq_n_f32(0);
350
+ for (int b = 0; b < nb; b++) {
351
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
352
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
353
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
354
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
355
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
356
+
357
+ int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
358
+ int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
359
+ int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
360
+ int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
361
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
362
+
363
+ int32x4_t ret0 = vdupq_n_s32(0);
364
+ int32x4_t ret1 = vdupq_n_s32(0);
365
+
366
+ ret0 = vdotq_s32(ret0, b0 << 4, a0);
367
+ ret1 = vdotq_s32(ret1, b1 << 4, a0);
368
+ ret0 = vdotq_s32(ret0, b2 << 4, a1);
369
+ ret1 = vdotq_s32(ret1, b3 << 4, a1);
370
+
371
+ ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
372
+ ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
373
+ ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
374
+ ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
375
+
376
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
377
+
378
+ acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
379
+ vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
380
+ a_ptr++;
381
+ b_ptr++;
382
+ }
383
+ vst1q_f32(s, acc);
384
+ s += ncols_interleaved;
385
+ }
386
+ return;
387
+ }
388
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
389
+ float sumf[4];
390
+ int sumi;
391
+
392
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
393
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
394
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
395
+
396
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
397
+ for (int l = 0; l < nb; l++) {
398
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
399
+ for (int j = 0; j < ncols_interleaved; j++) {
400
+ sumi = 0;
401
+ for (int i = 0; i < blocklen; ++i) {
402
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
403
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
404
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
405
+ }
406
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
407
+ }
408
+ }
409
+ }
410
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
411
+ }
412
+ }
413
+
414
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
415
+ const int qk = QK8_0;
416
+ const int nb = n / qk;
417
+ const int ncols_interleaved = 8;
418
+ const int blocklen = 8;
419
+
420
+ assert (n % qk == 0);
421
+ assert (nc % ncols_interleaved == 0);
422
+
423
+ UNUSED(s);
424
+ UNUSED(bs);
425
+ UNUSED(vx);
426
+ UNUSED(vy);
427
+ UNUSED(nr);
428
+ UNUSED(nc);
429
+ UNUSED(nb);
430
+ UNUSED(ncols_interleaved);
431
+ UNUSED(blocklen);
432
+
433
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
434
+ #if defined(__ARM_FEATURE_SVE)
435
+ if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
436
+ const void * b_ptr = vx;
437
+ const void * a_ptr = vy;
438
+ float * res_ptr = s;
439
+
440
+ __asm__ __volatile__(
441
+ "ptrue p0.b\n"
442
+ "add %x[b_ptr], %x[b_ptr], #0x10\n"
443
+ "1:" // Column loop
444
+ "add x22, %x[a_ptr], #0x2\n"
445
+ "mov z31.b, #0x0\n"
446
+ "mov x21, %x[nb]\n"
447
+ "2:" // Block loop
448
+ "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
449
+ "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
450
+ "mov z28.s, #0x0\n"
451
+ "mov z27.s, #0x0\n"
452
+ "ld1rd { z26.d }, p0/Z, [x22]\n"
453
+ "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
454
+ "sub x20, x22, #0x2\n"
455
+ "sub x21, x21, #0x1\n"
456
+ "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
457
+ "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
458
+ "lsl z22.b, z30.b, #0x4\n"
459
+ "lsl z16.b, z29.b, #0x4\n"
460
+ "and z30.b, z30.b, #0xf0\n"
461
+ "and z29.b, z29.b, #0xf0\n"
462
+ "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
463
+ "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
464
+ "lsl z19.b, z25.b, #0x4\n"
465
+ "and z25.b, z25.b, #0xf0\n"
466
+ "ld1rh { z17.h }, p0/Z, [x20]\n"
467
+ "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
468
+ "sdot z28.s, z22.b, z26.b\n"
469
+ "sdot z27.s, z16.b, z26.b\n"
470
+ "lsl z16.b, z24.b, #0x4\n"
471
+ "add x22, x22, #0x22\n"
472
+ "and z24.b, z24.b, #0xf0\n"
473
+ "add %x[b_ptr], %x[b_ptr], #0x90\n"
474
+ "fcvt z17.s, p0/m, z17.h\n"
475
+ "fcvt z18.s, p0/m, z18.h\n"
476
+ "sdot z28.s, z19.b, z23.b\n"
477
+ "sdot z27.s, z16.b, z23.b\n"
478
+ "fmul z18.s, z18.s, z17.s\n"
479
+ "sdot z28.s, z30.b, z21.b\n"
480
+ "sdot z27.s, z29.b, z21.b\n"
481
+ "sdot z28.s, z25.b, z20.b\n"
482
+ "sdot z27.s, z24.b, z20.b\n"
483
+ "uzp1 z17.s, z28.s, z27.s\n"
484
+ "uzp2 z16.s, z28.s, z27.s\n"
485
+ "add z17.s, z17.s, z16.s\n"
486
+ "asr z17.s, z17.s, #0x4\n"
487
+ "scvtf z17.s, p0/m, z17.s\n"
488
+ "fmla z31.s, p0/M, z17.s, z18.s\n"
489
+ "cbnz x21, 2b\n"
490
+ "sub %x[nc], %x[nc], #0x8\n"
491
+ "st1w { z31.s }, p0, [%x[res_ptr]]\n"
492
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
493
+ "cbnz %x[nc], 1b\n"
494
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
495
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
496
+ : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
497
+ );
498
+ return;
499
+ }
500
+ #endif // #if defined(__ARM_FEATURE_SVE)
501
+
502
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
503
+ {
504
+ float sumf[8];
505
+ int sumi;
506
+
507
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
508
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
509
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
510
+
511
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
512
+ for (int l = 0; l < nb; l++) {
513
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
514
+ for (int j = 0; j < ncols_interleaved; j++) {
515
+ sumi = 0;
516
+ for (int i = 0; i < blocklen; ++i) {
517
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
518
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
519
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
520
+ }
521
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
522
+ }
523
+ }
524
+ }
525
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
526
+ }
527
+ }
528
+ }
529
+
530
+ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
531
+ const int qk = QK8_0;
532
+ const int nb = n / qk;
533
+ const int ncols_interleaved = 4;
534
+ const int blocklen = 4;
535
+
536
+ assert (n % qk == 0);
537
+ assert (nc % ncols_interleaved == 0);
538
+
539
+ UNUSED(s);
540
+ UNUSED(bs);
541
+ UNUSED(vx);
542
+ UNUSED(vy);
543
+ UNUSED(nr);
544
+ UNUSED(nc);
545
+ UNUSED(nb);
546
+ UNUSED(ncols_interleaved);
547
+ UNUSED(blocklen);
548
+
549
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
550
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
551
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
552
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
553
+ float * res_ptr = s;
554
+
555
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
556
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
557
+
558
+ float32x4_t sumf = vdupq_n_f32(0);
559
+ for (int l = 0; l < nb; l++) {
560
+ uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
561
+ uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
562
+ uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
563
+ uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
564
+
565
+ int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
566
+ int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
567
+ int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
568
+ int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
569
+ int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
570
+ int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
571
+ int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
572
+ int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
573
+
574
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
575
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
576
+
577
+ int32x4_t sumi = vdupq_n_s32(0);
578
+ sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
579
+ sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
580
+ sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
581
+ sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
582
+ sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
583
+ sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
584
+ sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
585
+ sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
586
+
587
+ float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
588
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
589
+ float32x4_t d = a_d * b_d;
590
+
591
+ sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
592
+ }
593
+
594
+ vst1q_f32(res_ptr + x * 4, sumf);
595
+ }
596
+ return;
597
+ }
598
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
599
+ {
600
+ float sumf[4];
601
+ int sumi;
602
+
603
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
604
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
605
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
606
+
607
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
608
+ for (int l = 0; l < nb; l++) {
609
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
610
+ for (int j = 0; j < ncols_interleaved; j++) {
611
+ sumi = 0;
612
+ for (int i = 0; i < blocklen; ++i) {
613
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
614
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
615
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
616
+ }
617
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
618
+ }
619
+ }
620
+ }
621
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
622
+ }
623
+ }
624
+ }
625
+
626
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
627
+ const int qk = QK8_0;
628
+ const int nb = n / qk;
629
+ const int ncols_interleaved = 4;
630
+ const int blocklen = 4;
631
+
632
+ assert (n % qk == 0);
633
+ assert (nr % 4 == 0);
634
+ assert (nc % ncols_interleaved == 0);
635
+
636
+ UNUSED(s);
637
+ UNUSED(bs);
638
+ UNUSED(vx);
639
+ UNUSED(vy);
640
+ UNUSED(nr);
641
+ UNUSED(nc);
642
+ UNUSED(nb);
643
+ UNUSED(ncols_interleaved);
644
+ UNUSED(blocklen);
645
+
646
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
647
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
648
+ const void * b_ptr = vx;
649
+ const void * a_ptr = vy;
650
+ float * res_ptr = s;
651
+ size_t res_stride = bs * sizeof(float);
652
+
653
+ __asm__ __volatile__(
654
+ "mov x10, %x[nr]\n"
655
+ "mov x9, #0x88\n"
656
+ "cmp x10, #0x10\n"
657
+ "mul x9, %x[nb], x9\n"
658
+ "blt 4f\n"
659
+ "1:" // Row loop
660
+ "add x28, %x[b_ptr], #0x8\n"
661
+ "mov x27, %x[nc]\n"
662
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
663
+ "2:" // Column loop
664
+ "add x25, %x[a_ptr], #0x8\n"
665
+ "movi v15.16b, #0x0\n"
666
+ "movi v19.16b, #0x0\n"
667
+ "mov x24, %x[nb]\n"
668
+ "add x23, x25, x9\n"
669
+ "movi v18.16b, #0x0\n"
670
+ "movi v14.16b, #0x0\n"
671
+ "add x22, x23, x9\n"
672
+ "movi v11.16b, #0x0\n"
673
+ "movi v13.16b, #0x0\n"
674
+ "add x21, x22, x9\n"
675
+ "movi v23.16b, #0x0\n"
676
+ "movi v16.16b, #0x0\n"
677
+ "movi v25.16b, #0x0\n"
678
+ "movi v7.16b, #0x0\n"
679
+ "movi v0.16b, #0x0\n"
680
+ "movi v4.16b, #0x0\n"
681
+ "movi v5.16b, #0x0\n"
682
+ "movi v21.16b, #0x0\n"
683
+ "movi v8.16b, #0x0\n"
684
+ "movi v1.16b, #0x0\n"
685
+ "3:" // Block loop
686
+ "ldr q3, [x28, #0x0]\n"
687
+ "ldr q31, [x25, #0x0]\n"
688
+ "movi v28.16b, #0x4\n"
689
+ "movi v10.4s, #0x0\n"
690
+ "ldr q22, [x28, #0x10]\n"
691
+ "ldr q6, [x25, #0x10]\n"
692
+ "movi v29.4s, #0x0\n"
693
+ "movi v9.4s, #0x0\n"
694
+ "ldr q27, [x28, #0x20]\n"
695
+ "ldr q30, [x28, #0x30]\n"
696
+ "movi v20.4s, #0x0\n"
697
+ "movi v24.16b, #0xf0\n"
698
+ "ldr d2, [x25, #-0x8]\n"
699
+ "ldr d26, [x23, #-0x8]\n"
700
+ "sshl v12.16b, v3.16b, v28.16b\n"
701
+ "sub x20, x28, #0x8\n"
702
+ "ldr d17, [x20, #0x0]\n"
703
+ "and v3.16b, v3.16b, v24.16b\n"
704
+ "subs x24, x24, #0x1\n"
705
+ "add x28, x28, #0x48\n"
706
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
707
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
708
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
709
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
710
+ "sshl v31.16b, v22.16b, v28.16b\n"
711
+ "and v22.16b, v22.16b, v24.16b\n"
712
+ "fcvtl v17.4s, v17.4h\n"
713
+ "fcvtl v2.4s, v2.4h\n"
714
+ "fcvtl v26.4s, v26.4h\n"
715
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
716
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
717
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
718
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
719
+ "sshl v6.16b, v27.16b, v28.16b\n"
720
+ "sshl v28.16b, v30.16b, v28.16b\n"
721
+ "and v27.16b, v27.16b, v24.16b\n"
722
+ "and v30.16b, v30.16b, v24.16b\n"
723
+ "ldr q24, [x25, #0x20]\n"
724
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
725
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
726
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
727
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
728
+ "ldr q24, [x25, #0x30]\n"
729
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
730
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
731
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
732
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
733
+ "ldr q24, [x25, #0x40]\n"
734
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
735
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
736
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
737
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
738
+ "ldr q24, [x25, #0x50]\n"
739
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
740
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
741
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
742
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
743
+ "ldr q24, [x25, #0x60]\n"
744
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
745
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
746
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
747
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
748
+ "ldr q24, [x25, #0x70]\n"
749
+ "add x25, x25, #0x88\n"
750
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
751
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
752
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
753
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
754
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
755
+ "scvtf v10.4s, v10.4s, #0x4\n"
756
+ "scvtf v29.4s, v29.4s, #0x4\n"
757
+ "scvtf v9.4s, v9.4s, #0x4\n"
758
+ "scvtf v20.4s, v20.4s, #0x4\n"
759
+ "fmla v15.4s, v10.4s, v24.4s\n"
760
+ "ldr q24, [x23, #0x0]\n"
761
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
762
+ "fmla v19.4s, v29.4s, v10.4s\n"
763
+ "ldr q10, [x23, #0x10]\n"
764
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
765
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
766
+ "fmla v18.4s, v9.4s, v29.4s\n"
767
+ "movi v9.4s, #0x0\n"
768
+ "movi v29.4s, #0x0\n"
769
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
770
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
771
+ "fmla v14.4s, v20.4s, v2.4s\n"
772
+ "movi v20.4s, #0x0\n"
773
+ "movi v2.4s, #0x0\n"
774
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
775
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
776
+ "ldr q24, [x23, #0x20]\n"
777
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
778
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
779
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
780
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
781
+ "ldr q10, [x23, #0x30]\n"
782
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
783
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
784
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
785
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
786
+ "ldr q24, [x23, #0x40]\n"
787
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
788
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
789
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
790
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
791
+ "ldr q10, [x23, #0x50]\n"
792
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
793
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
794
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
795
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
796
+ "ldr q24, [x23, #0x60]\n"
797
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
798
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
799
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
800
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
801
+ "ldr q10, [x23, #0x70]\n"
802
+ "add x23, x23, #0x88\n"
803
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
804
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
805
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
806
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
807
+ "ldr q24, [x22, #0x0]\n"
808
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
809
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
810
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
811
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
812
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
813
+ "scvtf v9.4s, v9.4s, #0x4\n"
814
+ "scvtf v29.4s, v29.4s, #0x4\n"
815
+ "scvtf v20.4s, v20.4s, #0x4\n"
816
+ "scvtf v2.4s, v2.4s, #0x4\n"
817
+ "fmla v11.4s, v9.4s, v10.4s\n"
818
+ "ldr q9, [x22, #0x10]\n"
819
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
820
+ "fmla v13.4s, v29.4s, v10.4s\n"
821
+ "ldr d29, [x22, #-0x8]\n"
822
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
823
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
824
+ "fcvtl v29.4s, v29.4h\n"
825
+ "fmla v23.4s, v20.4s, v10.4s\n"
826
+ "movi v20.4s, #0x0\n"
827
+ "movi v10.4s, #0x0\n"
828
+ "fmla v16.4s, v2.4s, v26.4s\n"
829
+ "movi v26.4s, #0x0\n"
830
+ "movi v2.4s, #0x0\n"
831
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
832
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
833
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
834
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
835
+ "ldr q24, [x22, #0x20]\n"
836
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
837
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
838
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
839
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
840
+ "ldr q9, [x22, #0x30]\n"
841
+ ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
842
+ ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
843
+ ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
844
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
845
+ "ldr q24, [x22, #0x40]\n"
846
+ ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
847
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
848
+ ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
849
+ ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
850
+ "ldr q9, [x22, #0x50]\n"
851
+ ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
852
+ ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
853
+ ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
854
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
855
+ "ldr q24, [x22, #0x60]\n"
856
+ ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
857
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
858
+ ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
859
+ ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
860
+ "ldr q9, [x22, #0x70]\n"
861
+ "add x22, x22, #0x88\n"
862
+ ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
863
+ ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
864
+ ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
865
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
866
+ "ldr q24, [x21, #0x0]\n"
867
+ ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
868
+ ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
869
+ ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
870
+ ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
871
+ "fmul v9.4s, v17.4s, v29.s[0]\n"
872
+ "scvtf v20.4s, v20.4s, #0x4\n"
873
+ "scvtf v10.4s, v10.4s, #0x4\n"
874
+ "scvtf v26.4s, v26.4s, #0x4\n"
875
+ "scvtf v2.4s, v2.4s, #0x4\n"
876
+ "fmla v25.4s, v20.4s, v9.4s\n"
877
+ "ldr q9, [x21, #0x10]\n"
878
+ "fmul v20.4s, v17.4s, v29.s[1]\n"
879
+ "fmla v7.4s, v10.4s, v20.4s\n"
880
+ "ldr d20, [x21, #-0x8]\n"
881
+ "fmul v10.4s, v17.4s, v29.s[2]\n"
882
+ "fmul v29.4s, v17.4s, v29.s[3]\n"
883
+ "fcvtl v20.4s, v20.4h\n"
884
+ "fmla v0.4s, v26.4s, v10.4s\n"
885
+ "movi v26.4s, #0x0\n"
886
+ "movi v10.4s, #0x0\n"
887
+ "fmla v4.4s, v2.4s, v29.4s\n"
888
+ "movi v2.4s, #0x0\n"
889
+ "movi v29.4s, #0x0\n"
890
+ ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
891
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
892
+ ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
893
+ ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
894
+ "ldr q12, [x21, #0x20]\n"
895
+ "fmul v24.4s, v17.4s, v20.s[0]\n"
896
+ ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
897
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
898
+ ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
899
+ ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
900
+ "ldr q9, [x21, #0x30]\n"
901
+ "fmul v31.4s, v17.4s, v20.s[1]\n"
902
+ ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
903
+ ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
904
+ ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
905
+ ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
906
+ "ldr q12, [x21, #0x40]\n"
907
+ "fmul v6.4s, v17.4s, v20.s[2]\n"
908
+ "fmul v20.4s, v17.4s, v20.s[3]\n"
909
+ ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
910
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
911
+ ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
912
+ ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
913
+ "ldr q9, [x21, #0x50]\n"
914
+ ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
915
+ ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
916
+ ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
917
+ ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
918
+ "ldr q12, [x21, #0x60]\n"
919
+ ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
920
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
921
+ ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
922
+ ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
923
+ "ldr q17, [x21, #0x70]\n"
924
+ "add x21, x21, #0x88\n"
925
+ ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
926
+ ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
927
+ ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
928
+ ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
929
+ ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
930
+ ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
931
+ ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
932
+ ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
933
+ "scvtf v26.4s, v26.4s, #0x4\n"
934
+ "scvtf v10.4s, v10.4s, #0x4\n"
935
+ "fmla v5.4s, v26.4s, v24.4s\n"
936
+ "scvtf v2.4s, v2.4s, #0x4\n"
937
+ "scvtf v29.4s, v29.4s, #0x4\n"
938
+ "fmla v21.4s, v10.4s, v31.4s\n"
939
+ "fmla v8.4s, v2.4s, v6.4s\n"
940
+ "fmla v1.4s, v29.4s, v20.4s\n"
941
+ "bgt 3b\n"
942
+ "mov x20, %x[res_ptr]\n"
943
+ "subs x27, x27, #0x4\n"
944
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
945
+ "str q15, [x20, #0x0]\n"
946
+ "add x20, x20, %x[res_stride]\n"
947
+ "str q19, [x20, #0x0]\n"
948
+ "add x20, x20, %x[res_stride]\n"
949
+ "str q18, [x20, #0x0]\n"
950
+ "add x20, x20, %x[res_stride]\n"
951
+ "str q14, [x20, #0x0]\n"
952
+ "add x20, x20, %x[res_stride]\n"
953
+ "str q11, [x20, #0x0]\n"
954
+ "add x20, x20, %x[res_stride]\n"
955
+ "str q13, [x20, #0x0]\n"
956
+ "add x20, x20, %x[res_stride]\n"
957
+ "str q23, [x20, #0x0]\n"
958
+ "add x20, x20, %x[res_stride]\n"
959
+ "str q16, [x20, #0x0]\n"
960
+ "add x20, x20, %x[res_stride]\n"
961
+ "str q25, [x20, #0x0]\n"
962
+ "add x20, x20, %x[res_stride]\n"
963
+ "str q7, [x20, #0x0]\n"
964
+ "add x20, x20, %x[res_stride]\n"
965
+ "str q0, [x20, #0x0]\n"
966
+ "add x20, x20, %x[res_stride]\n"
967
+ "str q4, [x20, #0x0]\n"
968
+ "add x20, x20, %x[res_stride]\n"
969
+ "str q5, [x20, #0x0]\n"
970
+ "add x20, x20, %x[res_stride]\n"
971
+ "str q21, [x20, #0x0]\n"
972
+ "add x20, x20, %x[res_stride]\n"
973
+ "str q8, [x20, #0x0]\n"
974
+ "add x20, x20, %x[res_stride]\n"
975
+ "str q1, [x20, #0x0]\n"
976
+ "bne 2b\n"
977
+ "mov x20, #0x4\n"
978
+ "sub x10, x10, #0x10\n"
979
+ "cmp x10, #0x10\n"
980
+ "mov %x[res_ptr], x26\n"
981
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
982
+ "bge 1b\n"
983
+ "4:" // Row loop skip
984
+ "cbz x10, 9f\n"
985
+ "5:" // Row tail: Row loop
986
+ "add x24, %x[b_ptr], #0x8\n"
987
+ "mov x23, %x[nc]\n"
988
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
989
+ "6:" // Row tail: Column loop
990
+ "movi v15.16b, #0x0\n"
991
+ "movi v19.16b, #0x0\n"
992
+ "add x25, %x[a_ptr], #0x8\n"
993
+ "mov x21, %x[nb]\n"
994
+ "movi v18.16b, #0x0\n"
995
+ "movi v14.16b, #0x0\n"
996
+ "7:" // Row tail: Block loop
997
+ "ldr q7, [x24, #0x0]\n"
998
+ "ldr q5, [x25, #0x0]\n"
999
+ "movi v9.16b, #0x4\n"
1000
+ "movi v4.4s, #0x0\n"
1001
+ "ldr q3, [x24, #0x10]\n"
1002
+ "ldr q2, [x25, #0x10]\n"
1003
+ "movi v1.4s, #0x0\n"
1004
+ "movi v0.4s, #0x0\n"
1005
+ "ldr q13, [x24, #0x20]\n"
1006
+ "ldr q31, [x25, #0x20]\n"
1007
+ "movi v30.4s, #0x0\n"
1008
+ "movi v29.16b, #0xf0\n"
1009
+ "ldr q28, [x24, #0x30]\n"
1010
+ "ldr q27, [x25, #0x30]\n"
1011
+ "sshl v20.16b, v7.16b, v9.16b\n"
1012
+ "sub x20, x24, #0x8\n"
1013
+ "ldr q26, [x25, #0x40]\n"
1014
+ "ldr q25, [x25, #0x50]\n"
1015
+ "sshl v17.16b, v3.16b, v9.16b\n"
1016
+ "and v7.16b, v7.16b, v29.16b\n"
1017
+ "ldr q24, [x25, #0x60]\n"
1018
+ "ldr q16, [x25, #0x70]\n"
1019
+ "sshl v22.16b, v13.16b, v9.16b\n"
1020
+ "and v3.16b, v3.16b, v29.16b\n"
1021
+ "ldr d21, [x20, #0x0]\n"
1022
+ "ldr d12, [x25, #-0x8]\n"
1023
+ ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
1024
+ ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
1025
+ ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
1026
+ ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
1027
+ "sshl v9.16b, v28.16b, v9.16b\n"
1028
+ "subs x21, x21, #0x1\n"
1029
+ "and v13.16b, v13.16b, v29.16b\n"
1030
+ "and v28.16b, v28.16b, v29.16b\n"
1031
+ "add x25, x25, #0x88\n"
1032
+ "add x24, x24, #0x48\n"
1033
+ "fcvtl v21.4s, v21.4h\n"
1034
+ "fcvtl v12.4s, v12.4h\n"
1035
+ ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
1036
+ ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
1037
+ ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
1038
+ ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
1039
+ "fmul v11.4s, v21.4s, v12.s[0]\n"
1040
+ "fmul v23.4s, v21.4s, v12.s[1]\n"
1041
+ "fmul v17.4s, v21.4s, v12.s[2]\n"
1042
+ ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
1043
+ "fmul v6.4s, v21.4s, v12.s[3]\n"
1044
+ ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
1045
+ ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
1046
+ ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
1047
+ ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
1048
+ ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
1049
+ ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
1050
+ ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
1051
+ ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
1052
+ ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
1053
+ ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
1054
+ ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
1055
+ ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
1056
+ ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
1057
+ ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
1058
+ ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
1059
+ ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
1060
+ ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
1061
+ ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
1062
+ ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
1063
+ ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
1064
+ ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
1065
+ ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
1066
+ ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
1067
+ "scvtf v4.4s, v4.4s, #0x4\n"
1068
+ "scvtf v1.4s, v1.4s, #0x4\n"
1069
+ "scvtf v0.4s, v0.4s, #0x4\n"
1070
+ "fmla v15.4s, v4.4s, v11.4s\n"
1071
+ "scvtf v30.4s, v30.4s, #0x4\n"
1072
+ "fmla v19.4s, v1.4s, v23.4s\n"
1073
+ "fmla v18.4s, v0.4s, v17.4s\n"
1074
+ "fmla v14.4s, v30.4s, v6.4s\n"
1075
+ "bgt 7b\n"
1076
+ "mov x20, %x[res_ptr]\n"
1077
+ "cmp x10, #0x1\n"
1078
+ "str q15, [x20, #0x0]\n"
1079
+ "add x20, x20, %x[res_stride]\n"
1080
+ "ble 8f\n"
1081
+ "cmp x10, #0x2\n"
1082
+ "str q19, [x20, #0x0]\n"
1083
+ "add x20, x20, %x[res_stride]\n"
1084
+ "ble 8f\n"
1085
+ "cmp x10, #0x3\n"
1086
+ "str q18, [x20, #0x0]\n"
1087
+ "add x20, x20, %x[res_stride]\n"
1088
+ "ble 8f\n"
1089
+ "str q14, [x20, #0x0]\n"
1090
+ "8:" // Row tail: Accumulator store skip
1091
+ "subs x23, x23, #0x4\n"
1092
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
1093
+ "bne 6b\n"
1094
+ "subs x10, x10, #0x4\n"
1095
+ "add %x[a_ptr], %x[a_ptr], x9\n"
1096
+ "mov %x[res_ptr], x22\n"
1097
+ "bgt 5b\n"
1098
+ "9:" // Row tail: Row loop skip
1099
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
1100
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
1101
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
1102
+ );
1103
+ return;
1104
+ }
1105
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
1106
+ {
1107
+ float sumf[4][4];
1108
+ int sumi;
1109
+
1110
+ for (int y = 0; y < nr / 4; y++) {
1111
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1112
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1113
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
1114
+ for (int m = 0; m < 4; m++) {
1115
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1116
+ }
1117
+ for (int l = 0; l < nb; l++) {
1118
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1119
+ for (int m = 0; m < 4; m++) {
1120
+ for (int j = 0; j < ncols_interleaved; j++) {
1121
+ sumi = 0;
1122
+ for (int i = 0; i < blocklen; ++i) {
1123
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
1124
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
1125
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1126
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
1127
+ }
1128
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
1129
+ }
1130
+ }
1131
+ }
1132
+ }
1133
+ for (int m = 0; m < 4; m++) {
1134
+ for (int j = 0; j < ncols_interleaved; j++)
1135
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1136
+ }
1137
+ }
1138
+ }
1139
+ }
1140
+ }
1141
+
1142
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1143
+ const int qk = QK8_0;
1144
+ const int nb = n / qk;
1145
+ const int ncols_interleaved = 4;
1146
+ const int blocklen = 8;
1147
+
1148
+ assert (n % qk == 0);
1149
+ assert (nr % 4 == 0);
1150
+ assert (nc % ncols_interleaved == 0);
1151
+
1152
+ UNUSED(s);
1153
+ UNUSED(bs);
1154
+ UNUSED(vx);
1155
+ UNUSED(vy);
1156
+ UNUSED(nr);
1157
+ UNUSED(nc);
1158
+ UNUSED(nb);
1159
+ UNUSED(ncols_interleaved);
1160
+ UNUSED(blocklen);
1161
+
1162
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
1163
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
1164
+ const void * b_ptr = vx;
1165
+ const void * a_ptr = vy;
1166
+ float * res_ptr = s;
1167
+ size_t res_stride = bs * sizeof(float);
1168
+
1169
+ __asm__ __volatile__(
1170
+ "mov x10, %x[nr]\n"
1171
+ "mov x9, #0x88\n"
1172
+ "cmp x10, #0x10\n"
1173
+ "mul x9, %x[nb], x9\n"
1174
+ "blt 4f\n"
1175
+ "1:" // Row loop
1176
+ "add x28, %x[b_ptr], #0x8\n"
1177
+ "mov x27, %x[nc]\n"
1178
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
1179
+ "2:" // Column loop
1180
+ "add x25, %x[a_ptr], #0x8\n"
1181
+ "movi v2.16b, #0x0\n"
1182
+ "movi v10.16b, #0x0\n"
1183
+ "mov x24, %x[nb]\n"
1184
+ "add x23, x25, x9\n"
1185
+ "movi v12.16b, #0x0\n"
1186
+ "movi v28.16b, #0x0\n"
1187
+ "add x22, x23, x9\n"
1188
+ "movi v11.16b, #0x0\n"
1189
+ "movi v13.16b, #0x0\n"
1190
+ "add x21, x22, x9\n"
1191
+ "movi v22.16b, #0x0\n"
1192
+ "movi v23.16b, #0x0\n"
1193
+ "movi v25.16b, #0x0\n"
1194
+ "movi v5.16b, #0x0\n"
1195
+ "movi v7.16b, #0x0\n"
1196
+ "movi v4.16b, #0x0\n"
1197
+ "movi v6.16b, #0x0\n"
1198
+ "movi v30.16b, #0x0\n"
1199
+ "movi v24.16b, #0x0\n"
1200
+ "movi v14.16b, #0x0\n"
1201
+ "3:" // Block loop
1202
+ "ldr q21, [x28, #0x0]\n"
1203
+ "ldr q16, [x28, #0x10]\n"
1204
+ "movi v1.16b, #0x4\n"
1205
+ "movi v19.4s, #0x0\n"
1206
+ "ldr q27, [x25, #0x0]\n"
1207
+ "ldr q15, [x25, #0x10]\n"
1208
+ "movi v26.4s, #0x0\n"
1209
+ "movi v18.4s, #0x0\n"
1210
+ "ldr q29, [x28, #0x20]\n"
1211
+ "ldr q3, [x28, #0x30]\n"
1212
+ "movi v17.4s, #0x0\n"
1213
+ "movi v0.16b, #0xf0\n"
1214
+ "ldr d20, [x25, #-0x8]\n"
1215
+ "ldr d9, [x23, #-0x8]\n"
1216
+ "sshl v8.16b, v21.16b, v1.16b\n"
1217
+ "sshl v31.16b, v16.16b, v1.16b\n"
1218
+ "and v21.16b, v21.16b, v0.16b\n"
1219
+ "and v16.16b, v16.16b, v0.16b\n"
1220
+ "sub x20, x28, #0x8\n"
1221
+ "subs x24, x24, #0x1\n"
1222
+ "add x28, x28, #0x48\n"
1223
+ ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
1224
+ ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
1225
+ "ldr q27, [x25, #0x20]\n"
1226
+ ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
1227
+ ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
1228
+ "sshl v15.16b, v29.16b, v1.16b\n"
1229
+ "sshl v1.16b, v3.16b, v1.16b\n"
1230
+ "and v29.16b, v29.16b, v0.16b\n"
1231
+ "and v3.16b, v3.16b, v0.16b\n"
1232
+ "ldr q0, [x25, #0x30]\n"
1233
+ "fcvtl v20.4s, v20.4h\n"
1234
+ ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
1235
+ "fcvtl v9.4s, v9.4h\n"
1236
+ ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
1237
+ "ldr q27, [x25, #0x40]\n"
1238
+ ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
1239
+ ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
1240
+ "ldr q0, [x25, #0x50]\n"
1241
+ ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
1242
+ ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
1243
+ "ldr q27, [x25, #0x60]\n"
1244
+ ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
1245
+ ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
1246
+ "ldr q0, [x25, #0x70]\n"
1247
+ "add x25, x25, #0x88\n"
1248
+ ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
1249
+ ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
1250
+ "ldr d27, [x20, #0x0]\n"
1251
+ ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
1252
+ ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
1253
+ "fcvtl v27.4s, v27.4h\n"
1254
+ "uzp1 v0.2d, v19.2d, v26.2d\n"
1255
+ "uzp2 v26.2d, v19.2d, v26.2d\n"
1256
+ "fmul v19.4s, v27.4s, v20.s[0]\n"
1257
+ "scvtf v0.4s, v0.4s, #0x4\n"
1258
+ "scvtf v26.4s, v26.4s, #0x4\n"
1259
+ "fmla v2.4s, v0.4s, v19.4s\n"
1260
+ "ldr q19, [x23, #0x0]\n"
1261
+ "uzp1 v0.2d, v18.2d, v17.2d\n"
1262
+ "uzp2 v18.2d, v18.2d, v17.2d\n"
1263
+ "fmul v17.4s, v27.4s, v20.s[1]\n"
1264
+ "scvtf v0.4s, v0.4s, #0x4\n"
1265
+ "scvtf v18.4s, v18.4s, #0x4\n"
1266
+ "fmla v10.4s, v26.4s, v17.4s\n"
1267
+ "ldr q17, [x23, #0x10]\n"
1268
+ "fmul v26.4s, v27.4s, v20.s[2]\n"
1269
+ "fmul v20.4s, v27.4s, v20.s[3]\n"
1270
+ "fmla v12.4s, v0.4s, v26.4s\n"
1271
+ "ldr d0, [x22, #-0x8]\n"
1272
+ "ldr d26, [x21, #-0x8]\n"
1273
+ "fcvtl v0.4s, v0.4h\n"
1274
+ "fmla v28.4s, v18.4s, v20.4s\n"
1275
+ "movi v20.4s, #0x0\n"
1276
+ "movi v18.4s, #0x0\n"
1277
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
1278
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
1279
+ "ldr q19, [x23, #0x20]\n"
1280
+ "fcvtl v26.4s, v26.4h\n"
1281
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
1282
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
1283
+ "ldr q19, [x23, #0x40]\n"
1284
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
1285
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
1286
+ "ldr q19, [x23, #0x60]\n"
1287
+ ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
1288
+ ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
1289
+ "uzp1 v19.2d, v20.2d, v18.2d\n"
1290
+ "scvtf v19.4s, v19.4s, #0x4\n"
1291
+ "uzp2 v20.2d, v20.2d, v18.2d\n"
1292
+ "fmul v18.4s, v27.4s, v9.s[0]\n"
1293
+ "scvtf v20.4s, v20.4s, #0x4\n"
1294
+ "fmla v11.4s, v19.4s, v18.4s\n"
1295
+ "ldr q18, [x22, #0x0]\n"
1296
+ "fmul v19.4s, v27.4s, v9.s[1]\n"
1297
+ "fmla v13.4s, v20.4s, v19.4s\n"
1298
+ "movi v19.4s, #0x0\n"
1299
+ "movi v20.4s, #0x0\n"
1300
+ ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
1301
+ ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
1302
+ "ldr q17, [x23, #0x30]\n"
1303
+ ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
1304
+ ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
1305
+ "ldr q17, [x23, #0x50]\n"
1306
+ ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
1307
+ ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
1308
+ "ldr q17, [x23, #0x70]\n"
1309
+ "add x23, x23, #0x88\n"
1310
+ ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
1311
+ ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
1312
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
1313
+ "scvtf v17.4s, v17.4s, #0x4\n"
1314
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
1315
+ "fmul v19.4s, v27.4s, v9.s[2]\n"
1316
+ "fmul v9.4s, v27.4s, v9.s[3]\n"
1317
+ "scvtf v20.4s, v20.4s, #0x4\n"
1318
+ "fmla v22.4s, v17.4s, v19.4s\n"
1319
+ "ldr q17, [x22, #0x10]\n"
1320
+ "movi v19.4s, #0x0\n"
1321
+ ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
1322
+ "fmla v23.4s, v20.4s, v9.4s\n"
1323
+ "movi v20.4s, #0x0\n"
1324
+ "movi v9.4s, #0x0\n"
1325
+ ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
1326
+ "ldr q18, [x22, #0x20]\n"
1327
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
1328
+ ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
1329
+ ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
1330
+ "ldr q18, [x22, #0x40]\n"
1331
+ ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
1332
+ ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
1333
+ "ldr q18, [x22, #0x60]\n"
1334
+ ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
1335
+ ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
1336
+ "movi v18.4s, #0x0\n"
1337
+ ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
1338
+ "ldr q17, [x22, #0x30]\n"
1339
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
1340
+ ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
1341
+ "ldr q17, [x22, #0x50]\n"
1342
+ ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
1343
+ ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
1344
+ "ldr q17, [x22, #0x70]\n"
1345
+ "add x22, x22, #0x88\n"
1346
+ ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
1347
+ ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
1348
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
1349
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
1350
+ "fmul v19.4s, v27.4s, v0.s[0]\n"
1351
+ "scvtf v17.4s, v17.4s, #0x4\n"
1352
+ "scvtf v20.4s, v20.4s, #0x4\n"
1353
+ "fmla v25.4s, v17.4s, v19.4s\n"
1354
+ "ldr q19, [x21, #0x0]\n"
1355
+ "fmul v17.4s, v27.4s, v0.s[1]\n"
1356
+ "fmla v5.4s, v20.4s, v17.4s\n"
1357
+ "ldr q17, [x21, #0x10]\n"
1358
+ "uzp1 v20.2d, v9.2d, v18.2d\n"
1359
+ "uzp2 v9.2d, v9.2d, v18.2d\n"
1360
+ "fmul v18.4s, v27.4s, v0.s[2]\n"
1361
+ "fmul v0.4s, v27.4s, v0.s[3]\n"
1362
+ "scvtf v20.4s, v20.4s, #0x4\n"
1363
+ "scvtf v9.4s, v9.4s, #0x4\n"
1364
+ "fmla v7.4s, v20.4s, v18.4s\n"
1365
+ "movi v20.4s, #0x0\n"
1366
+ "movi v18.4s, #0x0\n"
1367
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
1368
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
1369
+ "ldr q19, [x21, #0x20]\n"
1370
+ "fmla v4.4s, v9.4s, v0.4s\n"
1371
+ "movi v9.4s, #0x0\n"
1372
+ "movi v0.4s, #0x0\n"
1373
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
1374
+ "fmul v8.4s, v27.4s, v26.s[0]\n"
1375
+ ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
1376
+ "ldr q17, [x21, #0x30]\n"
1377
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
1378
+ "fmul v31.4s, v27.4s, v26.s[1]\n"
1379
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
1380
+ "ldr q19, [x21, #0x40]\n"
1381
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
1382
+ "fmul v15.4s, v27.4s, v26.s[2]\n"
1383
+ "fmul v27.4s, v27.4s, v26.s[3]\n"
1384
+ ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
1385
+ "ldr q1, [x21, #0x50]\n"
1386
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
1387
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
1388
+ "ldr q26, [x21, #0x60]\n"
1389
+ ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
1390
+ ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
1391
+ "ldr q21, [x21, #0x70]\n"
1392
+ "add x21, x21, #0x88\n"
1393
+ ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
1394
+ ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
1395
+ ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
1396
+ ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
1397
+ "uzp1 v29.2d, v20.2d, v18.2d\n"
1398
+ "uzp2 v21.2d, v20.2d, v18.2d\n"
1399
+ "scvtf v29.4s, v29.4s, #0x4\n"
1400
+ "uzp1 v18.2d, v9.2d, v0.2d\n"
1401
+ "uzp2 v16.2d, v9.2d, v0.2d\n"
1402
+ "scvtf v21.4s, v21.4s, #0x4\n"
1403
+ "fmla v6.4s, v29.4s, v8.4s\n"
1404
+ "scvtf v18.4s, v18.4s, #0x4\n"
1405
+ "scvtf v16.4s, v16.4s, #0x4\n"
1406
+ "fmla v30.4s, v21.4s, v31.4s\n"
1407
+ "fmla v24.4s, v18.4s, v15.4s\n"
1408
+ "fmla v14.4s, v16.4s, v27.4s\n"
1409
+ "bgt 3b\n"
1410
+ "mov x20, %x[res_ptr]\n"
1411
+ "subs x27, x27, #0x4\n"
1412
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
1413
+ "str q2, [x20, #0x0]\n"
1414
+ "add x20, x20, %x[res_stride]\n"
1415
+ "str q10, [x20, #0x0]\n"
1416
+ "add x20, x20, %x[res_stride]\n"
1417
+ "str q12, [x20, #0x0]\n"
1418
+ "add x20, x20, %x[res_stride]\n"
1419
+ "str q28, [x20, #0x0]\n"
1420
+ "add x20, x20, %x[res_stride]\n"
1421
+ "str q11, [x20, #0x0]\n"
1422
+ "add x20, x20, %x[res_stride]\n"
1423
+ "str q13, [x20, #0x0]\n"
1424
+ "add x20, x20, %x[res_stride]\n"
1425
+ "str q22, [x20, #0x0]\n"
1426
+ "add x20, x20, %x[res_stride]\n"
1427
+ "str q23, [x20, #0x0]\n"
1428
+ "add x20, x20, %x[res_stride]\n"
1429
+ "str q25, [x20, #0x0]\n"
1430
+ "add x20, x20, %x[res_stride]\n"
1431
+ "str q5, [x20, #0x0]\n"
1432
+ "add x20, x20, %x[res_stride]\n"
1433
+ "str q7, [x20, #0x0]\n"
1434
+ "add x20, x20, %x[res_stride]\n"
1435
+ "str q4, [x20, #0x0]\n"
1436
+ "add x20, x20, %x[res_stride]\n"
1437
+ "str q6, [x20, #0x0]\n"
1438
+ "add x20, x20, %x[res_stride]\n"
1439
+ "str q30, [x20, #0x0]\n"
1440
+ "add x20, x20, %x[res_stride]\n"
1441
+ "str q24, [x20, #0x0]\n"
1442
+ "add x20, x20, %x[res_stride]\n"
1443
+ "str q14, [x20, #0x0]\n"
1444
+ "bne 2b\n"
1445
+ "mov x20, #0x4\n"
1446
+ "sub x10, x10, #0x10\n"
1447
+ "cmp x10, #0x10\n"
1448
+ "mov %x[res_ptr], x26\n"
1449
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
1450
+ "bge 1b\n"
1451
+ "4:" // Row loop skip
1452
+ "cbz x10, 9f\n"
1453
+ "5:" // Row tail: Row loop
1454
+ "add x24, %x[b_ptr], #0x8\n"
1455
+ "mov x23, %x[nc]\n"
1456
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
1457
+ "6:" // Row tail: Column loop
1458
+ "movi v2.16b, #0x0\n"
1459
+ "movi v10.16b, #0x0\n"
1460
+ "add x25, %x[a_ptr], #0x8\n"
1461
+ "mov x21, %x[nb]\n"
1462
+ "movi v12.16b, #0x0\n"
1463
+ "movi v28.16b, #0x0\n"
1464
+ "7:" // Row tail: Block loop
1465
+ "ldr q6, [x24, #0x0]\n"
1466
+ "ldr q5, [x24, #0x10]\n"
1467
+ "movi v17.16b, #0x4\n"
1468
+ "movi v8.4s, #0x0\n"
1469
+ "ldr q4, [x25, #0x0]\n"
1470
+ "ldr q13, [x25, #0x10]\n"
1471
+ "movi v27.4s, #0x0\n"
1472
+ "movi v0.4s, #0x0\n"
1473
+ "ldr q31, [x24, #0x20]\n"
1474
+ "ldr q14, [x24, #0x30]\n"
1475
+ "movi v29.4s, #0x0\n"
1476
+ "movi v22.16b, #0xf0\n"
1477
+ "ldr q11, [x25, #0x20]\n"
1478
+ "ldr q23, [x25, #0x30]\n"
1479
+ "sshl v21.16b, v6.16b, v17.16b\n"
1480
+ "sshl v16.16b, v5.16b, v17.16b\n"
1481
+ "ldr q20, [x25, #0x40]\n"
1482
+ "ldr q26, [x25, #0x50]\n"
1483
+ "and v6.16b, v6.16b, v22.16b\n"
1484
+ "and v5.16b, v5.16b, v22.16b\n"
1485
+ "ldr q25, [x25, #0x60]\n"
1486
+ "ldr q3, [x25, #0x70]\n"
1487
+ "sshl v19.16b, v31.16b, v17.16b\n"
1488
+ "sshl v18.16b, v14.16b, v17.16b\n"
1489
+ "ldr d17, [x25, #-0x8]\n"
1490
+ ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
1491
+ ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
1492
+ "and v31.16b, v31.16b, v22.16b\n"
1493
+ ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
1494
+ ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
1495
+ "and v14.16b, v14.16b, v22.16b\n"
1496
+ "sub x20, x24, #0x8\n"
1497
+ "ldr d16, [x20, #0x0]\n"
1498
+ "subs x21, x21, #0x1\n"
1499
+ "add x25, x25, #0x88\n"
1500
+ "fcvtl v17.4s, v17.4h\n"
1501
+ "add x24, x24, #0x48\n"
1502
+ ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
1503
+ ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
1504
+ ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
1505
+ ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
1506
+ "fcvtl v16.4s, v16.4h\n"
1507
+ ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
1508
+ ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
1509
+ "fmul v23.4s, v16.4s, v17.s[0]\n"
1510
+ "fmul v21.4s, v16.4s, v17.s[1]\n"
1511
+ "fmul v1.4s, v16.4s, v17.s[2]\n"
1512
+ "fmul v20.4s, v16.4s, v17.s[3]\n"
1513
+ ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
1514
+ ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
1515
+ ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
1516
+ ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
1517
+ ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
1518
+ ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
1519
+ "uzp1 v19.2d, v8.2d, v27.2d\n"
1520
+ "uzp2 v18.2d, v8.2d, v27.2d\n"
1521
+ "scvtf v19.4s, v19.4s, #0x4\n"
1522
+ "uzp1 v17.2d, v0.2d, v29.2d\n"
1523
+ "uzp2 v16.2d, v0.2d, v29.2d\n"
1524
+ "scvtf v18.4s, v18.4s, #0x4\n"
1525
+ "fmla v2.4s, v19.4s, v23.4s\n"
1526
+ "scvtf v17.4s, v17.4s, #0x4\n"
1527
+ "scvtf v16.4s, v16.4s, #0x4\n"
1528
+ "fmla v10.4s, v18.4s, v21.4s\n"
1529
+ "fmla v12.4s, v17.4s, v1.4s\n"
1530
+ "fmla v28.4s, v16.4s, v20.4s\n"
1531
+ "bgt 7b\n"
1532
+ "mov x20, %x[res_ptr]\n"
1533
+ "cmp x10, #0x1\n"
1534
+ "str q2, [x20, #0x0]\n"
1535
+ "add x20, x20, %x[res_stride]\n"
1536
+ "ble 8f\n"
1537
+ "cmp x10, #0x2\n"
1538
+ "str q10, [x20, #0x0]\n"
1539
+ "add x20, x20, %x[res_stride]\n"
1540
+ "ble 8f\n"
1541
+ "cmp x10, #0x3\n"
1542
+ "str q12, [x20, #0x0]\n"
1543
+ "add x20, x20, %x[res_stride]\n"
1544
+ "ble 8f\n"
1545
+ "str q28, [x20, #0x0]\n"
1546
+ "8:" // Row tail: Accumulator store skip
1547
+ "subs x23, x23, #0x4\n"
1548
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
1549
+ "bne 6b\n"
1550
+ "subs x10, x10, #0x4\n"
1551
+ "add %x[a_ptr], %x[a_ptr], x9\n"
1552
+ "mov %x[res_ptr], x22\n"
1553
+ "bgt 5b\n"
1554
+ "9:" // Row tail: Row loop skip
1555
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
1556
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
1557
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
1558
+ );
1559
+ return;
1560
+ }
1561
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
1562
+ float sumf[4][4];
1563
+ int sumi;
1564
+
1565
+ for (int y = 0; y < nr / 4; y++) {
1566
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1567
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1568
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
1569
+ for (int m = 0; m < 4; m++) {
1570
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1571
+ }
1572
+ for (int l = 0; l < nb; l++) {
1573
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1574
+ for (int m = 0; m < 4; m++) {
1575
+ for (int j = 0; j < ncols_interleaved; j++) {
1576
+ sumi = 0;
1577
+ for (int i = 0; i < blocklen; ++i) {
1578
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
1579
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
1580
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1581
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
1582
+ }
1583
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
1584
+ }
1585
+ }
1586
+ }
1587
+ }
1588
+ for (int m = 0; m < 4; m++) {
1589
+ for (int j = 0; j < ncols_interleaved; j++)
1590
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1591
+ }
1592
+ }
1593
+ }
1594
+ }
1595
+
1596
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1597
+ const int qk = QK8_0;
1598
+ const int nb = n / qk;
1599
+ const int ncols_interleaved = 8;
1600
+ const int blocklen = 8;
1601
+
1602
+ assert (n % qk == 0);
1603
+ assert (nr % 4 == 0);
1604
+ assert (nc % ncols_interleaved == 0);
1605
+
1606
+ UNUSED(s);
1607
+ UNUSED(bs);
1608
+ UNUSED(vx);
1609
+ UNUSED(vy);
1610
+ UNUSED(nr);
1611
+ UNUSED(nc);
1612
+ UNUSED(nb);
1613
+ UNUSED(ncols_interleaved);
1614
+ UNUSED(blocklen);
1615
+
1616
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
1617
+ #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
1618
+ if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
1619
+ const void * b_ptr = vx;
1620
+ const void * a_ptr = vy;
1621
+ float * res_ptr = s;
1622
+ size_t res_stride = bs * sizeof(float);
1623
+
1624
+ __asm__ __volatile__(
1625
+ "mov x20, #0x4\n"
1626
+ "mov x13, %x[nr]\n"
1627
+ "mov z28.s, #-0x4\n"
1628
+ "mov x12, #0x88\n"
1629
+ "ptrue p1.b\n"
1630
+ "whilelt p0.s, XZR, x20\n"
1631
+ "cmp x13, #0x10\n"
1632
+ "mul x12, %x[nb], x12\n"
1633
+ "blt 4f\n"
1634
+ "1:" // Row loop
1635
+ "add x11, %x[b_ptr], #0x10\n"
1636
+ "mov x10, %x[nc]\n"
1637
+ "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
1638
+ "2:" // Column loop
1639
+ "add x28, %x[a_ptr], #0x8\n"
1640
+ "mov z24.b, #0x0\n"
1641
+ "mov z15.b, #0x0\n"
1642
+ "mov x27, %x[nb]\n"
1643
+ "add x26, x28, x12\n"
1644
+ "mov z12.b, #0x0\n"
1645
+ "mov z0.b, #0x0\n"
1646
+ "add x25, x26, x12\n"
1647
+ "mov z13.b, #0x0\n"
1648
+ "mov z1.b, #0x0\n"
1649
+ "add x24, x25, x12\n"
1650
+ "mov z20.b, #0x0\n"
1651
+ "mov z25.b, #0x0\n"
1652
+ "mov z11.b, #0x0\n"
1653
+ "mov z16.b, #0x0\n"
1654
+ "mov z19.b, #0x0\n"
1655
+ "mov z26.b, #0x0\n"
1656
+ "mov z8.b, #0x0\n"
1657
+ "mov z29.b, #0x0\n"
1658
+ "mov z27.b, #0x0\n"
1659
+ "mov z10.b, #0x0\n"
1660
+ "3:" // Block loop
1661
+ "ld1b { z30.b }, p1/Z, [x11]\n"
1662
+ "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
1663
+ "mov z18.s, #0x0\n"
1664
+ "mov z7.s, #0x0\n"
1665
+ "ld1rqb { z3.b }, p1/Z, [x28]\n"
1666
+ "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
1667
+ "mov z9.s, #0x0\n"
1668
+ "mov z22.s, #0x0\n"
1669
+ "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
1670
+ "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
1671
+ "sub x20, x11, #0x10\n"
1672
+ "sub x23, x28, #0x8\n"
1673
+ "lsl z31.b, z30.b, #0x4\n"
1674
+ "lsl z6.b, z21.b, #0x4\n"
1675
+ "ld1h { z23.s }, p1/Z, [x20]\n"
1676
+ "sub x22, x26, #0x8\n"
1677
+ "and z30.b, z30.b, #0xf0\n"
1678
+ "and z21.b, z21.b, #0xf0\n"
1679
+ "sub x21, x25, #0x8\n"
1680
+ "sub x20, x24, #0x8\n"
1681
+ "lsl z14.b, z4.b, #0x4\n"
1682
+ "lsl z2.b, z17.b, #0x4\n"
1683
+ "subs x27, x27, #0x1\n"
1684
+ "add x11, x11, #0x90\n"
1685
+ ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n"
1686
+ ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n"
1687
+ "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
1688
+ "and z4.b, z4.b, #0xf0\n"
1689
+ ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
1690
+ ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
1691
+ "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
1692
+ "and z17.b, z17.b, #0xf0\n"
1693
+ "fcvt z23.s, p1/m, z23.h\n"
1694
+ ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n"
1695
+ ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n"
1696
+ "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
1697
+ ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
1698
+ ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
1699
+ "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
1700
+ "fscale z23.s, p1/m, z23.s, z28.s\n"
1701
+ ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n"
1702
+ ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n"
1703
+ "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
1704
+ ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
1705
+ ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
1706
+ "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
1707
+ "add x28, x28, #0x88\n"
1708
+ ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n"
1709
+ ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n"
1710
+ "ld1h { z3.s }, p0/Z, [x23]\n"
1711
+ ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
1712
+ ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
1713
+ "fcvt z3.s, p1/m, z3.h\n"
1714
+ "uzp1 z5.d, z18.d, z7.d\n"
1715
+ "uzp2 z18.d, z18.d, z7.d\n"
1716
+ "mov z3.q, z3.q[0]\n"
1717
+ "uzp1 z7.d, z9.d, z22.d\n"
1718
+ "uzp2 z22.d, z9.d, z22.d\n"
1719
+ "fmul z9.s, z23.s, z3.s[0]\n"
1720
+ "scvtf z5.s, p1/m, z5.s\n"
1721
+ "scvtf z18.s, p1/m, z18.s\n"
1722
+ "scvtf z7.s, p1/m, z7.s\n"
1723
+ "scvtf z22.s, p1/m, z22.s\n"
1724
+ "fmla z24.s, p1/M, z5.s, z9.s\n"
1725
+ "ld1rqb { z5.b }, p1/Z, [x26]\n"
1726
+ "fmul z9.s, z23.s, z3.s[1]\n"
1727
+ "fmla z15.s, p1/M, z18.s, z9.s\n"
1728
+ "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
1729
+ "fmul z9.s, z23.s, z3.s[2]\n"
1730
+ "fmul z3.s, z23.s, z3.s[3]\n"
1731
+ "fmla z12.s, p1/M, z7.s, z9.s\n"
1732
+ "mov z9.s, #0x0\n"
1733
+ "ld1h { z7.s }, p0/Z, [x22]\n"
1734
+ ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
1735
+ "fmla z0.s, p1/M, z22.s, z3.s\n"
1736
+ "mov z22.s, #0x0\n"
1737
+ "ld1h { z3.s }, p0/Z, [x21]\n"
1738
+ ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
1739
+ "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
1740
+ "fcvt z7.s, p1/m, z7.h\n"
1741
+ "fcvt z3.s, p1/m, z3.h\n"
1742
+ ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
1743
+ ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
1744
+ "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
1745
+ "mov z7.q, z7.q[0]\n"
1746
+ "mov z3.q, z3.q[0]\n"
1747
+ ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
1748
+ ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
1749
+ "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
1750
+ ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
1751
+ ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
1752
+ "uzp1 z5.d, z9.d, z22.d\n"
1753
+ "scvtf z5.s, p1/m, z5.s\n"
1754
+ "uzp2 z22.d, z9.d, z22.d\n"
1755
+ "fmul z9.s, z23.s, z7.s[0]\n"
1756
+ "scvtf z22.s, p1/m, z22.s\n"
1757
+ "fmla z13.s, p1/M, z5.s, z9.s\n"
1758
+ "ld1rqb { z9.b }, p1/Z, [x25]\n"
1759
+ "fmul z5.s, z23.s, z7.s[1]\n"
1760
+ "fmla z1.s, p1/M, z22.s, z5.s\n"
1761
+ "mov z5.s, #0x0\n"
1762
+ "mov z22.s, #0x0\n"
1763
+ ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n"
1764
+ ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n"
1765
+ "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
1766
+ ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n"
1767
+ ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n"
1768
+ "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
1769
+ ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n"
1770
+ ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n"
1771
+ "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
1772
+ "add x26, x26, #0x88\n"
1773
+ ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n"
1774
+ ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n"
1775
+ "uzp1 z18.d, z5.d, z22.d\n"
1776
+ "scvtf z18.s, p1/m, z18.s\n"
1777
+ "uzp2 z22.d, z5.d, z22.d\n"
1778
+ "fmul z5.s, z23.s, z7.s[2]\n"
1779
+ "fmul z7.s, z23.s, z7.s[3]\n"
1780
+ "scvtf z22.s, p1/m, z22.s\n"
1781
+ "fmla z20.s, p1/M, z18.s, z5.s\n"
1782
+ "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
1783
+ "ld1h { z5.s }, p0/Z, [x20]\n"
1784
+ "fcvt z5.s, p1/m, z5.h\n"
1785
+ "fmla z25.s, p1/M, z22.s, z7.s\n"
1786
+ "mov z22.s, #0x0\n"
1787
+ "mov z7.s, #0x0\n"
1788
+ ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n"
1789
+ ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
1790
+ "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
1791
+ "mov z5.q, z5.q[0]\n"
1792
+ ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n"
1793
+ ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
1794
+ "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
1795
+ ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n"
1796
+ ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n"
1797
+ "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
1798
+ ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n"
1799
+ ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n"
1800
+ "uzp1 z9.d, z22.d, z7.d\n"
1801
+ "scvtf z9.s, p1/m, z9.s\n"
1802
+ "uzp2 z22.d, z22.d, z7.d\n"
1803
+ "fmul z7.s, z23.s, z3.s[0]\n"
1804
+ "scvtf z22.s, p1/m, z22.s\n"
1805
+ "fmla z11.s, p1/M, z9.s, z7.s\n"
1806
+ "ld1rqb { z9.b }, p1/Z, [x24]\n"
1807
+ "fmul z7.s, z23.s, z3.s[1]\n"
1808
+ "fmla z16.s, p1/M, z22.s, z7.s\n"
1809
+ "mov z22.s, #0x0\n"
1810
+ "mov z7.s, #0x0\n"
1811
+ ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n"
1812
+ ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n"
1813
+ "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
1814
+ ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n"
1815
+ ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n"
1816
+ "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
1817
+ ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n"
1818
+ ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
1819
+ "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
1820
+ "add x25, x25, #0x88\n"
1821
+ ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n"
1822
+ ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n"
1823
+ "uzp1 z18.d, z22.d, z7.d\n"
1824
+ "scvtf z18.s, p1/m, z18.s\n"
1825
+ "uzp2 z7.d, z22.d, z7.d\n"
1826
+ "fmul z22.s, z23.s, z3.s[2]\n"
1827
+ "fmul z3.s, z23.s, z3.s[3]\n"
1828
+ "scvtf z7.s, p1/m, z7.s\n"
1829
+ "fmla z19.s, p1/M, z18.s, z22.s\n"
1830
+ "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
1831
+ "fmul z22.s, z23.s, z5.s[0]\n"
1832
+ "fmla z26.s, p1/M, z7.s, z3.s\n"
1833
+ "mov z3.s, #0x0\n"
1834
+ "mov z7.s, #0x0\n"
1835
+ ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n"
1836
+ ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
1837
+ "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
1838
+ ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n"
1839
+ ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
1840
+ "mov z9.s, #0x0\n"
1841
+ ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n"
1842
+ "mov z31.s, #0x0\n"
1843
+ ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n"
1844
+ "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
1845
+ "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
1846
+ ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n"
1847
+ "fmul z14.s, z23.s, z5.s[1]\n"
1848
+ ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n"
1849
+ "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
1850
+ "fmul z2.s, z23.s, z5.s[2]\n"
1851
+ "fmul z23.s, z23.s, z5.s[3]\n"
1852
+ ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n"
1853
+ ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
1854
+ "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
1855
+ ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n"
1856
+ ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n"
1857
+ "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
1858
+ "add x24, x24, #0x88\n"
1859
+ ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n"
1860
+ ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n"
1861
+ ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n"
1862
+ ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n"
1863
+ "uzp1 z18.d, z3.d, z7.d\n"
1864
+ "uzp2 z5.d, z3.d, z7.d\n"
1865
+ "scvtf z18.s, p1/m, z18.s\n"
1866
+ "uzp1 z6.d, z9.d, z31.d\n"
1867
+ "uzp2 z9.d, z9.d, z31.d\n"
1868
+ "scvtf z5.s, p1/m, z5.s\n"
1869
+ "fmla z8.s, p1/M, z18.s, z22.s\n"
1870
+ "scvtf z6.s, p1/m, z6.s\n"
1871
+ "scvtf z9.s, p1/m, z9.s\n"
1872
+ "fmla z29.s, p1/M, z5.s, z14.s\n"
1873
+ "fmla z27.s, p1/M, z6.s, z2.s\n"
1874
+ "fmla z10.s, p1/M, z9.s, z23.s\n"
1875
+ "bgt 3b\n"
1876
+ "mov x20, %x[res_ptr]\n"
1877
+ "subs x10, x10, #0x8\n"
1878
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
1879
+ "st1w { z24.s }, p1, [x20]\n"
1880
+ "add x20, x20, %x[res_stride]\n"
1881
+ "st1w { z15.s }, p1, [x20]\n"
1882
+ "add x20, x20, %x[res_stride]\n"
1883
+ "st1w { z12.s }, p1, [x20]\n"
1884
+ "add x20, x20, %x[res_stride]\n"
1885
+ "st1w { z0.s }, p1, [x20]\n"
1886
+ "add x20, x20, %x[res_stride]\n"
1887
+ "st1w { z13.s }, p1, [x20]\n"
1888
+ "add x20, x20, %x[res_stride]\n"
1889
+ "st1w { z1.s }, p1, [x20]\n"
1890
+ "add x20, x20, %x[res_stride]\n"
1891
+ "st1w { z20.s }, p1, [x20]\n"
1892
+ "add x20, x20, %x[res_stride]\n"
1893
+ "st1w { z25.s }, p1, [x20]\n"
1894
+ "add x20, x20, %x[res_stride]\n"
1895
+ "st1w { z11.s }, p1, [x20]\n"
1896
+ "add x20, x20, %x[res_stride]\n"
1897
+ "st1w { z16.s }, p1, [x20]\n"
1898
+ "add x20, x20, %x[res_stride]\n"
1899
+ "st1w { z19.s }, p1, [x20]\n"
1900
+ "add x20, x20, %x[res_stride]\n"
1901
+ "st1w { z26.s }, p1, [x20]\n"
1902
+ "add x20, x20, %x[res_stride]\n"
1903
+ "st1w { z8.s }, p1, [x20]\n"
1904
+ "add x20, x20, %x[res_stride]\n"
1905
+ "st1w { z29.s }, p1, [x20]\n"
1906
+ "add x20, x20, %x[res_stride]\n"
1907
+ "st1w { z27.s }, p1, [x20]\n"
1908
+ "add x20, x20, %x[res_stride]\n"
1909
+ "st1w { z10.s }, p1, [x20]\n"
1910
+ "bne 2b\n"
1911
+ "mov x20, #0x4\n"
1912
+ "sub x13, x13, #0x10\n"
1913
+ "cmp x13, #0x10\n"
1914
+ "mov %x[res_ptr], x9\n"
1915
+ "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
1916
+ "bge 1b\n"
1917
+ "4:" // Row loop skip
1918
+ "cbz x13, 9f\n"
1919
+ "5:" // Row tail: Row loop
1920
+ "add x25, %x[b_ptr], #0x10\n"
1921
+ "mov x24, %x[nc]\n"
1922
+ "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
1923
+ "6:" // Row tail: Column loop
1924
+ "mov z24.b, #0x0\n"
1925
+ "mov z15.b, #0x0\n"
1926
+ "add x28, %x[a_ptr], #0x8\n"
1927
+ "mov x22, %x[nb]\n"
1928
+ "mov z12.b, #0x0\n"
1929
+ "mov z0.b, #0x0\n"
1930
+ "7:" // Row tail: Block loop
1931
+ "ld1b { z3.b }, p1/Z, [x25]\n"
1932
+ "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
1933
+ "mov z2.s, #0x0\n"
1934
+ "mov z25.s, #0x0\n"
1935
+ "ld1rqb { z26.b }, p1/Z, [x28]\n"
1936
+ "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
1937
+ "mov z27.s, #0x0\n"
1938
+ "mov z19.s, #0x0\n"
1939
+ "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
1940
+ "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
1941
+ "sub x21, x25, #0x10\n"
1942
+ "sub x20, x28, #0x8\n"
1943
+ "lsl z20.b, z3.b, #0x4\n"
1944
+ "lsl z4.b, z6.b, #0x4\n"
1945
+ "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
1946
+ "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
1947
+ "and z3.b, z3.b, #0xf0\n"
1948
+ "and z6.b, z6.b, #0xf0\n"
1949
+ "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
1950
+ "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
1951
+ "lsl z8.b, z29.b, #0x4\n"
1952
+ "lsl z14.b, z16.b, #0x4\n"
1953
+ "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
1954
+ "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
1955
+ ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n"
1956
+ ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n"
1957
+ "and z29.b, z29.b, #0xf0\n"
1958
+ "ld1h { z17.s }, p1/Z, [x21]\n"
1959
+ ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n"
1960
+ ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n"
1961
+ "and z16.b, z16.b, #0xf0\n"
1962
+ "ld1h { z4.s }, p0/Z, [x20]\n"
1963
+ "subs x22, x22, #0x1\n"
1964
+ "add x28, x28, #0x88\n"
1965
+ "fcvt z17.s, p1/m, z17.h\n"
1966
+ "add x25, x25, #0x90\n"
1967
+ ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n"
1968
+ ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n"
1969
+ "fcvt z4.s, p1/m, z4.h\n"
1970
+ ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n"
1971
+ ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n"
1972
+ "fscale z17.s, p1/m, z17.s, z28.s\n"
1973
+ "mov z4.q, z4.q[0]\n"
1974
+ ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n"
1975
+ ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n"
1976
+ "fmul z23.s, z17.s, z4.s[0]\n"
1977
+ "fmul z9.s, z17.s, z4.s[1]\n"
1978
+ "fmul z21.s, z17.s, z4.s[2]\n"
1979
+ "fmul z4.s, z17.s, z4.s[3]\n"
1980
+ ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n"
1981
+ ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n"
1982
+ ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n"
1983
+ ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n"
1984
+ ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n"
1985
+ ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n"
1986
+ "uzp1 z31.d, z2.d, z25.d\n"
1987
+ "uzp2 z13.d, z2.d, z25.d\n"
1988
+ "scvtf z31.s, p1/m, z31.s\n"
1989
+ "uzp1 z17.d, z27.d, z19.d\n"
1990
+ "uzp2 z18.d, z27.d, z19.d\n"
1991
+ "scvtf z13.s, p1/m, z13.s\n"
1992
+ "fmla z24.s, p1/M, z31.s, z23.s\n"
1993
+ "scvtf z17.s, p1/m, z17.s\n"
1994
+ "scvtf z18.s, p1/m, z18.s\n"
1995
+ "fmla z15.s, p1/M, z13.s, z9.s\n"
1996
+ "fmla z12.s, p1/M, z17.s, z21.s\n"
1997
+ "fmla z0.s, p1/M, z18.s, z4.s\n"
1998
+ "bgt 7b\n"
1999
+ "mov x20, %x[res_ptr]\n"
2000
+ "cmp x13, #0x1\n"
2001
+ "st1w { z24.s }, p1, [x20]\n"
2002
+ "add x20, x20, %x[res_stride]\n"
2003
+ "ble 8f\n"
2004
+ "cmp x13, #0x2\n"
2005
+ "st1w { z15.s }, p1, [x20]\n"
2006
+ "add x20, x20, %x[res_stride]\n"
2007
+ "ble 8f\n"
2008
+ "cmp x13, #0x3\n"
2009
+ "st1w { z12.s }, p1, [x20]\n"
2010
+ "add x20, x20, %x[res_stride]\n"
2011
+ "ble 8f\n"
2012
+ "st1w { z0.s }, p1, [x20]\n"
2013
+ "8:" // Row tail: Accumulator store skip
2014
+ "subs x24, x24, #0x8\n"
2015
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
2016
+ "bne 6b\n"
2017
+ "subs x13, x13, #0x4\n"
2018
+ "add %x[a_ptr], %x[a_ptr], x12\n"
2019
+ "mov %x[res_ptr], x23\n"
2020
+ "bgt 5b\n"
2021
+ "9:" // Row tail: Row loop skip
2022
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2023
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2024
+ : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
2025
+ );
2026
+ return;
2027
+ }
2028
+ #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2029
+
2030
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2031
+ float sumf[4][8];
2032
+ int sumi;
2033
+
2034
+ for (int y = 0; y < nr / 4; y++) {
2035
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2036
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
2037
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
2038
+ for (int m = 0; m < 4; m++) {
2039
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
2040
+ }
2041
+ for (int l = 0; l < nb; l++) {
2042
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
2043
+ for (int m = 0; m < 4; m++) {
2044
+ for (int j = 0; j < ncols_interleaved; j++) {
2045
+ sumi = 0;
2046
+ for (int i = 0; i < blocklen; ++i) {
2047
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
2048
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
2049
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
2050
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
2051
+ }
2052
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
2053
+ }
2054
+ }
2055
+ }
2056
+ }
2057
+ for (int m = 0; m < 4; m++) {
2058
+ for (int j = 0; j < ncols_interleaved; j++)
2059
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
2060
+ }
2061
+ }
2062
+ }
2063
+ }
2064
+
2065
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2066
+ const int qk = QK8_0;
2067
+ const int nb = n / qk;
2068
+ const int ncols_interleaved = 4;
2069
+ const int blocklen = 4;
2070
+
2071
+ assert (n % qk == 0);
2072
+ assert (nr % 4 == 0);
2073
+ assert (nc % ncols_interleaved == 0);
2074
+
2075
+ UNUSED(s);
2076
+ UNUSED(bs);
2077
+ UNUSED(vx);
2078
+ UNUSED(vy);
2079
+ UNUSED(nr);
2080
+ UNUSED(nc);
2081
+ UNUSED(nb);
2082
+ UNUSED(ncols_interleaved);
2083
+ UNUSED(blocklen);
2084
+
2085
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2086
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
2087
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
2088
+
2089
+ for (int y = 0; y < nr / 4; y++) {
2090
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2091
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
2092
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
2093
+
2094
+ float32x4_t sumf[4];
2095
+ for (int m = 0; m < 4; m++) {
2096
+ sumf[m] = vdupq_n_f32(0);
2097
+ }
2098
+
2099
+ for (int l = 0; l < nb; l++) {
2100
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
2101
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
2102
+
2103
+ int32x4_t sumi_0 = vdupq_n_s32(0);
2104
+ int32x4_t sumi_1 = vdupq_n_s32(0);
2105
+ int32x4_t sumi_2 = vdupq_n_s32(0);
2106
+ int32x4_t sumi_3 = vdupq_n_s32(0);
2107
+
2108
+ for (int k = 0; k < 4; k++) {
2109
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
2110
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
2111
+
2112
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
2113
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
2114
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
2115
+
2116
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
2117
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
2118
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
2119
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
2120
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
2121
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
2122
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
2123
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
2124
+ }
2125
+
2126
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
2127
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
2128
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
2129
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
2130
+ }
2131
+
2132
+ for (int m = 0; m < 4; m++) {
2133
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
2134
+ }
2135
+ }
2136
+ }
2137
+ return;
2138
+ }
2139
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2140
+ {
2141
+ float sumf[4][4];
2142
+ int sumi;
2143
+
2144
+ for (int y = 0; y < nr / 4; y++) {
2145
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2146
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
2147
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
2148
+ for (int m = 0; m < 4; m++) {
2149
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
2150
+ }
2151
+ for (int l = 0; l < nb; l++) {
2152
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
2153
+ for (int m = 0; m < 4; m++) {
2154
+ for (int j = 0; j < ncols_interleaved; j++) {
2155
+ sumi = 0;
2156
+ for (int i = 0; i < blocklen; ++i) {
2157
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
2158
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
2159
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
2160
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
2161
+ }
2162
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
2163
+ }
2164
+ }
2165
+ }
2166
+ }
2167
+ for (int m = 0; m < 4; m++) {
2168
+ for (int j = 0; j < ncols_interleaved; j++)
2169
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
2170
+ }
2171
+ }
2172
+ }
2173
+ }
2174
+ }
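Every repacked GEMV/GEMM routine above follows the same dispatch shape: a compile-time feature guard decides whether the optimized path is built at all, a runtime CPU check (ggml_cpu_has_neon(), ggml_cpu_has_sve(), ggml_cpu_has_matmul_int8(), ...) decides whether to actually take it, and the plain C loops at the end of each function are the portable fallback. The snippet below is only a minimal standalone sketch of that control flow; the vector-scaling kernel and the n >= 4 test are toy stand-ins, not ggml code.

// Standalone sketch of the compile-time-guard / runtime-check / scalar-fallback
// pattern used by the kernels above. The scale_* functions are toy stand-ins.
#include <stddef.h>

static void scale_scalar(float *v, size_t n, float f) {
    for (size_t i = 0; i < n; ++i) v[i] *= f;               // portable reference loop
}

#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
static void scale_neon(float *v, size_t n, float f) {
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        vst1q_f32(v + i, vmulq_n_f32(vld1q_f32(v + i), f)); // 4 floats per iteration
    }
    for (; i < n; ++i) v[i] *= f;                           // scalar tail
}
#endif

void scale_dispatch(float *v, size_t n, float f) {
#if defined(__aarch64__) && defined(__ARM_NEON)
    if (n >= 4) {   // in ggml this is a CPU-feature check such as ggml_cpu_has_neon()
        scale_neon(v, n, f);
        return;
    }
#endif
    scale_scalar(v, n, f);                                  // fall through to the reference loop
}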
ggml/src/ggml-cpu/arch/loongarch/quants.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/arch/powerpc/quants.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/arch/riscv/quants.c ADDED
@@ -0,0 +1,2068 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+
7
+ #include "../../quants.h"
8
+ #include "../../ggml-cpu-impl.h"
9
+
10
+ #include <math.h>
11
+ #include <string.h>
12
+ #include <assert.h>
13
+ #include <float.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
26
+ assert(QK8_0 == 32);
27
+ assert(k % QK8_0 == 0);
28
+ const int nb = k / QK8_0;
29
+
30
+ block_q8_0 * GGML_RESTRICT y = vy;
31
+
32
+ #if defined(__riscv_v)
33
+
34
+ size_t vl = QK8_0;
35
+
36
+ for (int i = 0; i < nb; i++) {
37
+ // load elements
38
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);
39
+
40
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
41
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
42
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
43
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
44
+
45
+ const float d = amax / ((1 << 7) - 1);
46
+ const float id = d ? 1.0f/d : 0.0f;
47
+
48
+ y[i].d = GGML_FP32_TO_FP16(d);
49
+
50
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
51
+
52
+ // convert to integer
53
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
54
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
55
+
56
+ // store result
57
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
58
+ }
59
+ #else
60
+ GGML_UNUSED(nb);
61
+ // scalar
62
+ quantize_row_q8_0_ref(x, y, k);
63
+ #endif
64
+ }
65
+
66
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
67
+ assert(k % QK8_1 == 0);
68
+ const int nb = k / QK8_1;
69
+
70
+ block_q8_1 * GGML_RESTRICT y = vy;
71
+
72
+ #if defined(__riscv_v)
73
+
74
+ size_t vl = QK8_1;
75
+
76
+ for (int i = 0; i < nb; i++) {
77
+ // load elements
78
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);
79
+
80
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
81
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
82
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
83
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
84
+
85
+ const float d = amax / ((1 << 7) - 1);
86
+ const float id = d ? 1.0f/d : 0.0f;
87
+
88
+ y[i].d = GGML_FP32_TO_FP16(d);
89
+
90
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
91
+
92
+ // convert to integer
93
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
94
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
95
+
96
+ // store result
97
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
98
+
99
+ // compute sum for y[i].s
100
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
101
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);
102
+
103
+ // set y[i].s
104
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
105
+ y[i].s = GGML_FP32_TO_FP16(sum*d);
106
+ }
107
+
108
+ #else
109
+ GGML_UNUSED(nb);
110
+ // scalar
111
+ quantize_row_q8_1_ref(x, y, k);
112
+ #endif
113
+ }
114
+
115
+ //===================================== Dot products =================================
116
+
117
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
118
+ const int qk = QK8_0;
119
+ const int nb = n / qk;
120
+
121
+ assert(n % qk == 0);
122
+ assert(nrc == 1);
123
+ UNUSED(nrc);
124
+ UNUSED(bx);
125
+ UNUSED(by);
126
+ UNUSED(bs);
127
+
128
+ const block_q4_0 * GGML_RESTRICT x = vx;
129
+ const block_q8_0 * GGML_RESTRICT y = vy;
130
+
131
+ int ib = 0;
132
+ float sumf = 0;
133
+
134
+ #if defined(__riscv_v)
135
+ size_t vl = qk / 2;
136
+
137
+ for (; ib < nb; ++ib) {
138
+ // load elements
139
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
140
+
141
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
142
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
143
+
144
+ // mask and store lower part of x, and then upper part
145
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
146
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
147
+
148
+ vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
149
+ vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
150
+
151
+ // subtract offset
152
+ vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
153
+ vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
154
+
155
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
156
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
157
+
158
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
159
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
160
+
161
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
162
+
163
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
164
+ }
165
+
166
+ #endif
167
+ for (; ib < nb; ++ib) {
168
+ int sumi0 = 0;
169
+ int sumi1 = 0;
170
+
171
+ for (int j = 0; j < qk/2; ++j) {
172
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
173
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
174
+
175
+ sumi0 += (v0 * y[ib].qs[j]);
176
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
177
+ }
178
+
179
+ int sumi = sumi0 + sumi1;
180
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
181
+ }
182
+
183
+ *s = sumf;
184
+ }
185
+
186
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
187
+ const int qk = QK8_1;
188
+ const int nb = n / qk;
189
+
190
+ assert(n % qk == 0);
191
+ assert(nrc == 1);
192
+ UNUSED(nrc);
193
+ UNUSED(bx);
194
+ UNUSED(by);
195
+ UNUSED(bs);
196
+
197
+ const block_q4_1 * GGML_RESTRICT x = vx;
198
+ const block_q8_1 * GGML_RESTRICT y = vy;
199
+
200
+ int ib = 0;
201
+ float sumf = 0;
202
+
203
+ #if defined(__riscv_v)
204
+ size_t vl = qk / 2;
205
+
206
+ for (; ib < nb; ++ib) {
207
+ // load elements
208
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
209
+
210
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
211
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
212
+
213
+ // extract the lower nibbles of x, then the upper nibbles
214
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
215
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
216
+
217
+ vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
218
+ vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
219
+
220
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
221
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
222
+
223
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
224
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
225
+
226
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
227
+
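+ // Q4_1 keeps its nibbles unsigned and adds a per-block minimum m; the x.m * y.s term below
+ // accounts for it, since y.s was precomputed as y.d * sum(y.q) during quantization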
228
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
229
+ }
230
+
231
+ #endif
232
+ for (; ib < nb; ++ib) {
233
+ int sumi0 = 0;
234
+ int sumi1 = 0;
235
+
236
+ for (int j = 0; j < qk/2; ++j) {
237
+ const int v0 = (x[ib].qs[j] & 0x0F);
238
+ const int v1 = (x[ib].qs[j] >> 4);
239
+
240
+ sumi0 += (v0 * y[ib].qs[j]);
241
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
242
+ }
243
+
244
+ int sumi = sumi0 + sumi1;
245
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
246
+ }
247
+
248
+ *s = sumf;
249
+ }
250
+
251
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
252
+ const int qk = QK8_0;
253
+ const int nb = n / qk;
254
+
255
+ int ib = 0;
256
+ float sumf = 0;
257
+
258
+ assert(n % qk == 0);
259
+ assert(qk == QK5_0);
260
+ assert(nrc == 1);
261
+ UNUSED(nrc);
262
+ UNUSED(bx);
263
+ UNUSED(by);
264
+ UNUSED(bs);
265
+
266
+ const block_q5_0 * GGML_RESTRICT x = vx;
267
+ const block_q8_0 * GGML_RESTRICT y = vy;
268
+
269
+ #if defined(__riscv_v)
270
+ size_t vl;
271
+ size_t vlenb = __riscv_vlenb();
272
+
273
+ for (; ib < nb; ++ib) {
274
+ vl = qk / 2;
275
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
276
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
277
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
278
+ vint8m2_t v0c;
279
+ if (vlenb == 16) {
280
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
281
+ } else {
282
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
283
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
284
+ }
285
+
286
+ vl = qk;
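+ // x[ib].qh holds the 5th bit of each quant; load it as a mask, invert it, and subtract 16
+ // wherever that bit is 0, which matches the scalar ((q | bit<<4) - 16) below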
287
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
288
+ qh = __riscv_vmnand_mm_b4(qh, qh, vl);
289
+ vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
290
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
291
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
292
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
293
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
294
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
295
+
296
+ sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi;
297
+ }
298
+
299
+ #endif
300
+ for (; ib < nb; ++ib) {
301
+ uint32_t qh;
302
+ memcpy(&qh, x[ib].qh, sizeof(qh));
303
+
304
+ int sumi0 = 0;
305
+ int sumi1 = 0;
306
+
307
+ for (int j = 0; j < qk/2; ++j) {
308
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
309
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
310
+
311
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
312
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
313
+
314
+ sumi0 += (x0 * y[ib].qs[j]);
315
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
316
+ }
317
+
318
+ int sumi = sumi0 + sumi1;
319
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
320
+ }
321
+
322
+ *s = sumf;
323
+ }
324
+
325
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
326
+ const int qk = QK8_1;
327
+ const int nb = n / qk;
328
+
329
+ int ib = 0;
330
+ float sumf = 0;
331
+
332
+ assert(n % qk == 0);
333
+ assert(qk == QK5_1);
334
+ assert(nrc == 1);
335
+ UNUSED(nrc);
336
+ UNUSED(bx);
337
+ UNUSED(by);
338
+ UNUSED(bs);
339
+
340
+ const block_q5_1 * GGML_RESTRICT x = vx;
341
+ const block_q8_1 * GGML_RESTRICT y = vy;
342
+
343
+ #if defined(__riscv_v)
344
+ size_t vl;
345
+ size_t vlenb = __riscv_vlenb();
346
+
347
+ for (; ib < nb; ++ib) {
348
+ vl = qk / 2;
349
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
350
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
351
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
352
+ vint8m2_t v0c;
353
+ if (vlenb == 16) {
354
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
355
+ } else {
356
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
357
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
358
+ }
359
+
360
+ vl = qk;
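+ // for Q5_1 the 5th bit is simply OR-ed in as +16 where it is set; unlike Q5_0 no -16 offset is applied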
361
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
362
+ vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
363
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
364
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
365
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
366
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
367
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
368
+
369
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
370
+ }
371
+
372
+ #endif
373
+ for (; ib < nb; ++ib) {
374
+ uint32_t qh;
375
+ memcpy(&qh, x[ib].qh, sizeof(qh));
376
+
377
+ int sumi0 = 0;
378
+ int sumi1 = 0;
379
+
380
+ for (int j = 0; j < qk/2; ++j) {
381
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
382
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
383
+
384
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
385
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
386
+
387
+ sumi0 += (x0 * y[ib].qs[j]);
388
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
389
+ }
390
+
391
+ int sumi = sumi0 + sumi1;
392
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
393
+ }
394
+
395
+ *s = sumf;
396
+ }
397
+
398
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
399
+ const int qk = QK8_0;
400
+ const int nb = n / qk;
401
+
402
+ assert(n % qk == 0);
403
+ assert(nrc == 1);
404
+ UNUSED(nrc);
405
+ UNUSED(bx);
406
+ UNUSED(by);
407
+ UNUSED(bs);
408
+
409
+ const block_q8_0 * GGML_RESTRICT x = vx;
410
+ const block_q8_0 * GGML_RESTRICT y = vy;
411
+
412
+ int ib = 0;
413
+ float sumf = 0;
414
+
415
+ #if defined(__riscv_v)
416
+ size_t vl = qk;
417
+
418
+ for (; ib < nb; ++ib) {
419
+ // load elements
420
+ vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
421
+ vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
422
+
423
+ vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);
424
+
425
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
426
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);
427
+
428
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
429
+
430
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
431
+ }
432
+
433
+ #endif
434
+ for (; ib < nb; ++ib) {
435
+ int sumi = 0;
436
+
437
+ for (int j = 0; j < qk; j++) {
438
+ sumi += x[ib].qs[j]*y[ib].qs[j];
439
+ }
440
+
441
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
442
+ }
443
+
444
+ *s = sumf;
445
+ }
446
+
447
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
448
+ assert(nrc == 1);
449
+ UNUSED(nrc);
450
+ UNUSED(bx);
451
+ UNUSED(by);
452
+ UNUSED(bs);
453
+
454
+ const block_q2_K * GGML_RESTRICT x = vx;
455
+ const block_q8_K * GGML_RESTRICT y = vy;
456
+
457
+ const int nb = n / QK_K;
458
+
459
+ #if defined __riscv_xtheadvector
460
+
461
+ float sumf = 0;
462
+ uint8_t atmp[16];
463
+
464
+ for (int i = 0; i < nb; ++i) {
465
+ const uint8_t * q2 = x[i].qs;
466
+ const int8_t * q8 = y[i].qs;
467
+ const uint8_t * sc = x[i].scales;
468
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
469
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
470
+ uint8_t *patmp = atmp;
471
+ int vsums;
472
+ int tmp;
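+ // the first asm block splits the 4-bit scales (low nibbles, stored to atmp) from the mins
+ // (high nibbles) and reduces sum(bsums[j] * mins[j]) into vsums; sumf += dmin * vsums then
+ // applies the minimum correction (dmin is negated above)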
473
+ __asm__ __volatile__(
474
+ "th.vsetvli zero, %[vl16], e8, m1\n\t"
475
+ "th.vmv.v.x v8, zero\n\t"
476
+ "th.vlb.v v1, (%[sc])\n\t"
477
+ "th.vand.vi v0, v1, 0xF\n\t"
478
+ "th.vsrl.vi v1, v1, 4\n\t"
479
+ "th.vsb.v v0, (%[scale])\n\t"
480
+ "th.vwaddu.vx v16, v1, zero\n\t"
481
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
482
+ "th.vlh.v v2, (%[bsums])\n\t"
483
+ "th.vwmul.vv v4, v16, v2\n\t"
484
+ "th.vsetvli zero, %[vl16], e32, m4\n\t"
485
+ "th.vredsum.vs v8, v4, v8\n\t"
486
+ "th.vmv.x.s %[vsums], v8"
487
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
488
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
489
+ , [vl16] "r" (16)
490
+ : "memory"
491
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
492
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
493
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
494
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
495
+ );
496
+ sumf += dmin * vsums;
497
+ int isum = 0;
498
+
499
+ for (int j = 0; j < QK_K/128; ++j) {
500
+ __asm__ __volatile__(
501
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
502
+ "th.vlb.v v0, (%[q2])\n\t"
503
+ "th.vsrl.vi v2, v0, 2\n\t"
504
+ "th.vsrl.vi v4, v0, 4\n\t"
505
+ "th.vsrl.vi v6, v0, 6\n\t"
506
+ "th.vand.vi v0, v0, 0x3\n\t"
507
+ "th.vand.vi v2, v2, 0x3\n\t"
508
+ "th.vand.vi v4, v4, 0x3\n\t"
509
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
510
+ "th.vlb.v v8, (%[q8])\n\t"
511
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
512
+ "th.vwmul.vv v16, v0, v8\n\t"
513
+ "th.vwmul.vv v24, v4, v12\n\t"
514
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
515
+ "th.vmv.v.x v0, zero\n\t"
516
+ "th.vwredsum.vs v10, v16, v0\n\t"
517
+ "th.vwredsum.vs v9, v18, v0\n\t"
518
+ "th.vwredsum.vs v8, v20, v0\n\t"
519
+ "th.vwredsum.vs v7, v22, v0\n\t"
520
+ "th.vwredsum.vs v11, v24, v0\n\t"
521
+ "th.vwredsum.vs v12, v26, v0\n\t"
522
+ "th.vwredsum.vs v13, v28, v0\n\t"
523
+ "th.vwredsum.vs v14, v30, v0\n\t"
524
+ "li %[tmp], 4\n\t"
525
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
526
+ "th.vslideup.vi v10, v9, 1\n\t"
527
+ "th.vslideup.vi v8, v7, 1\n\t"
528
+ "th.vslideup.vi v11, v12, 1\n\t"
529
+ "th.vslideup.vi v13, v14, 1\n\t"
530
+ "th.vslideup.vi v10, v8, 2\n\t"
531
+ "th.vslideup.vi v11, v13, 2\n\t"
532
+ "li %[tmp], 8\n\t"
533
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
534
+ "th.vlbu.v v12, (%[scale])\n\t"
535
+ "th.vmul.vv v10, v10, v12\n\t"
536
+ "th.vredsum.vs v0, v10, v0\n\t"
537
+ "th.vmv.x.s %[tmp], v0\n\t"
538
+ "add %[isum], %[isum], %[tmp]"
539
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
540
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
541
+ , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
542
+ : "memory"
543
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
544
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
545
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
546
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
547
+ );
548
+ q2 += 32; q8 += 128; patmp += 8;
549
+ }
550
+
551
+ sumf += dall * isum;
552
+ }
553
+
554
+ *s = sumf;
555
+
556
+ #elif defined __riscv_v
557
+
558
+ float sumf = 0;
559
+ uint8_t atmp[16];
560
+
561
+ const int vector_length = __riscv_vlenb() * 8;
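+ // __riscv_vlenb() returns VLEN in bytes, so this is VLEN in bits; it selects between the
+ // 256-bit intrinsic path and the 128-bit inline-assembly path in the switch below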
562
+ uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
563
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
564
+
565
+ switch (vector_length) {
566
+ case 256:
567
+ for (int i = 0; i < nb; ++i) {
568
+ const uint8_t * q2 = x[i].qs;
569
+ const int8_t * q8 = y[i].qs;
570
+ const uint8_t * sc = x[i].scales;
571
+
572
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
573
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
574
+
575
+ size_t vl = 16;
576
+
577
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
578
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
579
+
580
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
581
+
582
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
583
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
584
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
585
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
586
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
587
+
588
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
589
+
590
+ vl = 32;
591
+
592
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
593
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
594
+
595
+ uint8_t is = 0;
596
+ int isum = 0;
597
+
598
+ for (int j = 0; j < QK_K / 128; ++j) {
599
+ // load Q2
600
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
601
+
602
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
603
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
604
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
605
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
606
+
607
+ // duplicate scale elements for product
608
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
609
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
610
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
611
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
612
+
613
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
614
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
615
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
616
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
617
+
618
+ // load Q8
619
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
620
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
621
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
622
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
623
+
624
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
625
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
626
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
627
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
628
+
629
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
630
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
631
+
632
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);
633
+
634
+ q2 += 32;
635
+ q8 += 128;
636
+ is = 8;
637
+ }
638
+
639
+ sumf += dall * isum;
640
+ }
641
+ break;
642
+ case 128:
643
+ for (int i = 0; i < nb; ++i) {
644
+ const uint8_t * q2 = x[i].qs;
645
+ const int8_t * q8 = y[i].qs;
646
+ const uint8_t * sc = x[i].scales;
647
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
648
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
649
+ uint8_t *patmp = atmp;
650
+ int vsums;
651
+ int tmp;
652
+ __asm__ __volatile__(
653
+ "vsetivli zero, 16, e8, m1\n\t"
654
+ "vmv.v.x v8, zero\n\t"
655
+ "vle8.v v1, (%[sc])\n\t"
656
+ "vand.vi v0, v1, 0xF\n\t"
657
+ "vsrl.vi v1, v1, 4\n\t"
658
+ "vse8.v v0, (%[scale])\n\t"
659
+ "vsetivli zero, 16, e16, m2\n\t"
660
+ "vle16.v v2, (%[bsums])\n\t"
661
+ "vzext.vf2 v0, v1\n\t"
662
+ "vwmul.vv v4, v0, v2\n\t"
663
+ "vsetivli zero, 16, e32, m4\n\t"
664
+ "vredsum.vs v8, v4, v8\n\t"
665
+ "vmv.x.s %[vsums], v8"
666
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
667
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
668
+ : "memory"
669
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
670
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
671
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
672
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
673
+ );
674
+ sumf += dmin * vsums;
675
+ int isum = 0;
676
+
677
+ for (int j = 0; j < QK_K/128; ++j) {
678
+ __asm__ __volatile__(
679
+ "vsetvli zero, %[vl32], e8, m2\n\t"
680
+ "vle8.v v0, (%[q2])\n\t"
681
+ "vsrl.vi v2, v0, 2\n\t"
682
+ "vsrl.vi v4, v0, 4\n\t"
683
+ "vsrl.vi v6, v0, 6\n\t"
684
+ "vand.vi v0, v0, 0x3\n\t"
685
+ "vand.vi v2, v2, 0x3\n\t"
686
+ "vand.vi v4, v4, 0x3\n\t"
687
+ "vsetvli zero, %[vl128], e8, m8\n\t"
688
+ "vle8.v v8, (%[q8])\n\t"
689
+ "vsetvli zero, %[vl64], e8, m4\n\t"
690
+ "vwmul.vv v16, v0, v8\n\t"
691
+ "vwmul.vv v24, v4, v12\n\t"
692
+ "vsetivli zero, 16, e16, m2\n\t"
693
+ "vmv.v.x v0, zero\n\t"
694
+ "vwredsum.vs v10, v16, v0\n\t"
695
+ "vwredsum.vs v9, v18, v0\n\t"
696
+ "vwredsum.vs v8, v20, v0\n\t"
697
+ "vwredsum.vs v7, v22, v0\n\t"
698
+ "vwredsum.vs v11, v24, v0\n\t"
699
+ "vwredsum.vs v12, v26, v0\n\t"
700
+ "vwredsum.vs v13, v28, v0\n\t"
701
+ "vwredsum.vs v14, v30, v0\n\t"
702
+ "vsetivli zero, 4, e32, m1\n\t"
703
+ "vslideup.vi v10, v9, 1\n\t"
704
+ "vslideup.vi v8, v7, 1\n\t"
705
+ "vslideup.vi v11, v12, 1\n\t"
706
+ "vslideup.vi v13, v14, 1\n\t"
707
+ "vslideup.vi v10, v8, 2\n\t"
708
+ "vslideup.vi v11, v13, 2\n\t"
709
+ "vsetivli zero, 8, e32, m2\n\t"
710
+ "vle8.v v15, (%[scale])\n\t"
711
+ "vzext.vf4 v12, v15\n\t"
712
+ "vmul.vv v10, v10, v12\n\t"
713
+ "vredsum.vs v0, v10, v0\n\t"
714
+ "vmv.x.s %[tmp], v0\n\t"
715
+ "add %[isum], %[isum], %[tmp]"
716
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
717
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
718
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
719
+ : "memory"
720
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
721
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
722
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
723
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
724
+ );
725
+ q2 += 32; q8 += 128; patmp += 8;
726
+ }
727
+
728
+ sumf += dall * isum;
729
+ }
730
+ break;
731
+ default:
732
+ assert(false && "Unsupported vector length");
733
+ break;
734
+ }
735
+
736
+ *s = sumf;
737
+
738
+ #else
739
+
740
+ float sumf = 0;
741
+
742
+ for (int i = 0; i < nb; ++i) {
743
+
744
+ const uint8_t * q2 = x[i].qs;
745
+ const int8_t * q8 = y[i].qs;
746
+ const uint8_t * sc = x[i].scales;
747
+
748
+ int summs = 0;
749
+ for (int j = 0; j < 16; ++j) {
750
+ summs += y[i].bsums[j] * (sc[j] >> 4);
751
+ }
752
+
753
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
754
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
755
+
756
+ int isum = 0;
757
+ int is = 0;
758
+ int d;
759
+ for (int k = 0; k < QK_K/128; ++k) {
760
+ int shift = 0;
761
+ for (int j = 0; j < 4; ++j) {
762
+ d = sc[is++] & 0xF;
763
+ int isuml = 0;
764
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
765
+ isum += d * isuml;
766
+ d = sc[is++] & 0xF;
767
+ isuml = 0;
768
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
769
+ isum += d * isuml;
770
+ shift += 2;
771
+ q8 += 32;
772
+ }
773
+ q2 += 32;
774
+ }
775
+ sumf += dall * isum - dmin * summs;
776
+ }
777
+ *s = sumf;
778
+ #endif
779
+ }
780
+
781
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
782
+ assert(n % QK_K == 0);
783
+ assert(nrc == 1);
784
+ UNUSED(nrc);
785
+ UNUSED(bx);
786
+ UNUSED(by);
787
+ UNUSED(bs);
788
+
789
+ const uint32_t kmask1 = 0x03030303;
790
+ const uint32_t kmask2 = 0x0f0f0f0f;
791
+
792
+ const block_q3_K * GGML_RESTRICT x = vx;
793
+ const block_q8_K * GGML_RESTRICT y = vy;
794
+
795
+ const int nb = n / QK_K;
796
+
797
+ #if defined __riscv_xtheadvector
798
+
799
+ uint32_t utmp[4];
800
+ float sumf = 0;
801
+
802
+ for (int i = 0; i < nb; ++i) {
803
+ const uint8_t * restrict q3 = x[i].qs;
804
+ const uint8_t * restrict qh = x[i].hmask;
805
+ const int8_t * restrict q8 = y[i].qs;
806
+
807
+ int8_t * scale = (int8_t *)utmp;
808
+ int tmp;
809
+ __asm__ __volatile__(
810
+ "li %[tmp], 12\n\t"
811
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
812
+ "th.vlb.v v0, (%[s6b])\n\t"
813
+ "th.vmv.v.v v2, v0\n\t"
814
+ "li %[tmp], 2\n\t"
815
+ "th.vsetvli zero, %[tmp], e64, m1\n\t"
816
+ "th.vmv.v.x v9, %[sh]\n\t"\
817
+ "th.vslidedown.vi v1, v0, 1\n\t"
818
+ "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
819
+ "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
820
+ "li %[tmp], 4\n\t"
821
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
822
+ "th.vid.v v9\n\t"
823
+ "th.vmv.x.s %[tmp], v1\n\t"
824
+ "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
825
+ "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
826
+ "th.vsrl.vv v4, v1, v9\n\t"
827
+ "th.vsrl.vv v2, v0, v8\n\t"
828
+ "th.vand.vx v5, v4, %[kmask1]\n\t"
829
+ "th.vand.vx v3, v2, %[kmask2]\n\t"
830
+ "th.vsll.vi v6, v5, 4\n\t"
831
+ "th.vor.vv v7, v6, v3\n\t"
832
+ "li %[tmp], 16\n\t"
833
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
834
+ "th.vsub.vx v0, v7, %[c]\n\t"
835
+ "th.vsb.v v0, (%[scale])"
836
+ : [tmp] "=&r" (tmp)
837
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
838
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
839
+ : "memory"
840
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
841
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
842
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
843
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
844
+ );
845
+
846
+ uint8_t m = 1;
847
+ int isum = 0;
848
+ for (int j = 0; j < QK_K; j += 128) {
849
+ __asm__ __volatile__(
850
+ // fixme: use v0p7 mask layout directly
851
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
852
+ "th.vlb.v v8, (%[q3])\n\t"
853
+ "th.vsrl.vi v10, v8, 2\n\t"
854
+ "th.vsrl.vi v12, v8, 4\n\t"
855
+ "th.vsrl.vi v14, v8, 6\n\t"
856
+ "th.vand.vi v8, v8, 3\n\t"
857
+ "th.vand.vi v10, v10, 3\n\t"
858
+ "th.vand.vi v12, v12, 3\n\t"
859
+ "th.vlb.v v2, (%[qh])\n\t"
860
+ "th.vand.vx v4, v2, %[m]\n\t"
861
+ "slli %[m], %[m], 1\n\t"
862
+ "th.vmseq.vx v0, v4, zero\n\t"
863
+ "th.vadd.vi v8, v8, -4, v0.t\n\t"
864
+ "th.vand.vx v4, v2, %[m]\n\t"
865
+ "slli %[m], %[m], 1\n\t"
866
+ "th.vmseq.vx v0, v4, zero\n\t"
867
+ "th.vadd.vi v10, v10, -4, v0.t\n\t"
868
+ "th.vand.vx v4, v2, %[m]\n\t"
869
+ "slli %[m], %[m], 1\n\t"
870
+ "th.vmseq.vx v0, v4, zero\n\t"
871
+ "th.vadd.vi v12, v12, -4, v0.t\n\t"
872
+ "th.vand.vx v4, v2, %[m]\n\t"
873
+ "slli %[m], %[m], 1\n\t"
874
+ "th.vmseq.vx v0, v4, zero\n\t"
875
+ "th.vadd.vi v14, v14, -4, v0.t\n\t"
876
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
877
+ "th.vlb.v v0, (%[q8])\n\t"
878
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
879
+ "th.vwmul.vv v16, v0, v8\n\t"
880
+ "th.vwmul.vv v24, v4, v12\n\t"
881
+ "li %[tmp], 16\n\t"
882
+ "th.vsetvli zero, %[tmp], e16, m2\n\t"
883
+ "th.vmv.v.x v0, zero\n\t"
884
+ "th.vwredsum.vs v10, v16, v0\n\t"
885
+ "th.vwredsum.vs v9, v18, v0\n\t"
886
+ "th.vwredsum.vs v8, v20, v0\n\t"
887
+ "th.vwredsum.vs v7, v22, v0\n\t"
888
+ "th.vwredsum.vs v11, v24, v0\n\t"
889
+ "th.vwredsum.vs v12, v26, v0\n\t"
890
+ "th.vwredsum.vs v13, v28, v0\n\t"
891
+ "th.vwredsum.vs v14, v30, v0\n\t"
892
+ "li %[tmp], 4\n\t"
893
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
894
+ "th.vslideup.vi v10, v9, 1\n\t"
895
+ "th.vslideup.vi v8, v7, 1\n\t"
896
+ "th.vslideup.vi v11, v12, 1\n\t"
897
+ "th.vslideup.vi v13, v14, 1\n\t"
898
+ "th.vslideup.vi v10, v8, 2\n\t"
899
+ "th.vslideup.vi v11, v13, 2\n\t"
900
+ "li %[tmp], 8\n\t"
901
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
902
+ "th.vlb.v v12, (%[scale])\n\t"
903
+ "th.vmul.vv v10, v10, v12\n\t"
904
+ "th.vredsum.vs v0, v10, v0\n\t"
905
+ "th.vmv.x.s %[tmp], v0\n\t"
906
+ "add %[isum], %[isum], %[tmp]"
907
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
908
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
909
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
910
+ : "memory"
911
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
912
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
913
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
914
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
915
+ );
916
+ q3 += 32; q8 += 128; scale += 8;
917
+ }
918
+
919
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
920
+ sumf += d * isum;
921
+ }
922
+
923
+ *s = sumf;
924
+
925
+ #elif defined __riscv_v
926
+
927
+ uint32_t utmp[4];
928
+ float sumf = 0;
929
+ uint32_t aux[3];
930
+ const int vector_length = __riscv_vlenb() * 8;
931
+
932
+ switch (vector_length) {
933
+ case 256:
934
+ for (int i = 0; i < nb; ++i) {
935
+
936
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
937
+ const uint8_t * GGML_RESTRICT qh = x[i].hmask;
938
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
939
+
940
+ memcpy(aux, x[i].scales, 12);
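+ // the 12 packed scale bytes of Q3_K unpack via kmask1/kmask2 into 16 6-bit scales,
+ // re-centered to signed values by subtracting 32 below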
941
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
942
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
943
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
944
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
945
+
946
+ int8_t * scale = (int8_t *)utmp;
947
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
948
+
949
+
950
+ size_t vl = 32;
951
+ uint8_t m = 1;
952
+
953
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
954
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
955
+
956
+ int sum_t = 0;
957
+
958
+ for (int j = 0; j < QK_K; j += 128) {
959
+
960
+ vl = 32;
961
+
962
+ // load Q3
963
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
964
+
965
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
966
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
967
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
968
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
969
+
970
+ // compute mask for subtraction
971
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
972
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
973
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
974
+ m <<= 1;
975
+
976
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
977
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
978
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
979
+ m <<= 1;
980
+
981
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
982
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
983
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
984
+ m <<= 1;
985
+
986
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
987
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
988
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
989
+ m <<= 1;
990
+
991
+ // load Q8 and take product with Q3
992
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
993
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
994
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
995
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
996
+
997
+ vl = 16;
998
+
999
+ // split each widened product into its two 16-element halves and multiply by the corresponding scale
1000
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
1001
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
1002
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
1003
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
1004
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
1005
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
1006
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
1007
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
1008
+
1009
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
1010
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
1011
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
1012
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
1013
+
1014
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1015
+
1016
+ q3 += 32; q8 += 128; scale += 8;
1017
+
1018
+ }
1019
+
1020
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1021
+
1022
+ sumf += d*sum_t;
1023
+
1024
+ }
1025
+ break;
1026
+ case 128:
1027
+ for (int i = 0; i < nb; ++i) {
1028
+ const uint8_t * restrict q3 = x[i].qs;
1029
+ const uint8_t * restrict qh = x[i].hmask;
1030
+ const int8_t * restrict q8 = y[i].qs;
1031
+
1032
+ int8_t * scale = (int8_t *)utmp;
1033
+ int tmp;
1034
+ __asm__ __volatile__(
1035
+ "vsetivli zero, 12, e8, m1\n\t"
1036
+ "vle8.v v0, (%[s6b])\n\t"
1037
+ "vmv1r.v v2, v0\n\t"
1038
+ "vsetivli zero, 2, e64, m1\n\t"
1039
+ "vmv.v.x v9, %[sh]\n\t"\
1040
+ "vslidedown.vi v1, v0, 1\n\t"
1041
+ "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
1042
+ "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
1043
+ "vsetivli zero, 4, e32, m1\n\t"
1044
+ "vid.v v9\n\t"
1045
+ "vmv.x.s %[tmp], v1\n\t"
1046
+ "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
1047
+ "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
1048
+ "vsrl.vv v4, v1, v9\n\t"
1049
+ "vsrl.vv v2, v0, v8\n\t"
1050
+ "vand.vx v5, v4, %[kmask1]\n\t"
1051
+ "vand.vx v3, v2, %[kmask2]\n\t"
1052
+ "vsll.vi v6, v5, 4\n\t"
1053
+ "vor.vv v7, v6, v3\n\t"
1054
+ "vsetivli zero, 16, e8, m1\n\t"
1055
+ "vsub.vx v0, v7, %[c]\n\t"
1056
+ "vse8.v v0, (%[scale])"
1057
+ : [tmp] "=&r" (tmp)
1058
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
1059
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
1060
+ : "memory"
1061
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1062
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1063
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1064
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1065
+ );
1066
+
1067
+ uint8_t m = 1;
1068
+ int isum = 0;
1069
+ for (int j = 0; j < QK_K; j += 128) {
1070
+ __asm__ __volatile__(
1071
+ "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
1072
+ "vle8.v v8, (%[q3])\n\t"
1073
+ "vsrl.vi v10, v8, 2\n\t"
1074
+ "vsrl.vi v12, v8, 4\n\t"
1075
+ "vsrl.vi v14, v8, 6\n\t"
1076
+ "vand.vi v8, v8, 3\n\t"
1077
+ "vand.vi v10, v10, 3\n\t"
1078
+ "vand.vi v12, v12, 3\n\t"
1079
+ "vle8.v v2, (%[qh])\n\t"
1080
+ "vand.vx v4, v2, %[m]\n\t"
1081
+ "slli %[m], %[m], 1\n\t"
1082
+ "vmseq.vx v0, v4, zero\n\t"
1083
+ "vadd.vi v8, v8, -4, v0.t\n\t"
1084
+ "vand.vx v4, v2, %[m]\n\t"
1085
+ "slli %[m], %[m], 1\n\t"
1086
+ "vmseq.vx v0, v4, zero\n\t"
1087
+ "vadd.vi v10, v10, -4, v0.t\n\t"
1088
+ "vand.vx v4, v2, %[m]\n\t"
1089
+ "slli %[m], %[m], 1\n\t"
1090
+ "vmseq.vx v0, v4, zero\n\t"
1091
+ "vadd.vi v12, v12, -4, v0.t\n\t"
1092
+ "vand.vx v4, v2, %[m]\n\t"
1093
+ "slli %[m], %[m], 1\n\t"
1094
+ "vmseq.vx v0, v4, zero\n\t"
1095
+ "vadd.vi v14, v14, -4, v0.t\n\t"
1096
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1097
+ "vle8.v v0, (%[q8])\n\t"
1098
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1099
+ "vwmul.vv v16, v0, v8\n\t"
1100
+ "vwmul.vv v24, v4, v12\n\t"
1101
+ "vsetivli zero, 16, e16, m2\n\t"
1102
+ "vmv.v.x v0, zero\n\t"
1103
+ "vwredsum.vs v10, v16, v0\n\t"
1104
+ "vwredsum.vs v9, v18, v0\n\t"
1105
+ "vwredsum.vs v8, v20, v0\n\t"
1106
+ "vwredsum.vs v7, v22, v0\n\t"
1107
+ "vwredsum.vs v11, v24, v0\n\t"
1108
+ "vwredsum.vs v12, v26, v0\n\t"
1109
+ "vwredsum.vs v13, v28, v0\n\t"
1110
+ "vwredsum.vs v14, v30, v0\n\t"
1111
+ "vsetivli zero, 4, e32, m1\n\t"
1112
+ "vslideup.vi v10, v9, 1\n\t"
1113
+ "vslideup.vi v8, v7, 1\n\t"
1114
+ "vslideup.vi v11, v12, 1\n\t"
1115
+ "vslideup.vi v13, v14, 1\n\t"
1116
+ "vslideup.vi v10, v8, 2\n\t"
1117
+ "vslideup.vi v11, v13, 2\n\t"
1118
+ "vsetivli zero, 8, e32, m2\n\t"
1119
+ "vle8.v v15, (%[scale])\n\t"
1120
+ "vsext.vf4 v12, v15\n\t"
1121
+ "vmul.vv v10, v10, v12\n\t"
1122
+ "vredsum.vs v0, v10, v0\n\t"
1123
+ "vmv.x.s %[tmp], v0\n\t"
1124
+ "add %[isum], %[isum], %[tmp]"
1125
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
1126
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
1127
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
1128
+ : "memory"
1129
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1130
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1131
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1132
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1133
+ );
1134
+ q3 += 32; q8 += 128; scale += 8;
1135
+ }
1136
+
1137
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1138
+ sumf += d * isum;
1139
+ }
1140
+ break;
1141
+ default:
1142
+ assert(false && "Unsupported vector length");
1143
+ break;
1144
+ }
1145
+
1146
+ *s = sumf;
1147
+
1148
+ #else
1149
+ // scalar version
1150
+ // This function is written like this so the compiler can manage to vectorize most of it
1151
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so of the
1152
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1153
+ // The ideal situation would be if we could just write the code once, and the compiler would
1154
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
1155
+ // write vectorized versions for AVX, ARM_NEON, etc.
1156
+
1157
+ int8_t aux8[QK_K];
1158
+ int16_t aux16[8];
1159
+ float sums [8];
1160
+ int32_t aux32[8];
1161
+ memset(sums, 0, 8*sizeof(float));
1162
+
1163
+ uint32_t auxs[4];
1164
+ const int8_t * scales = (const int8_t*)auxs;
1165
+
1166
+ float sumf = 0;
1167
+ for (int i = 0; i < nb; ++i) {
1168
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1169
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
1170
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1171
+ memset(aux32, 0, 8*sizeof(int32_t));
1172
+ int8_t * GGML_RESTRICT a = aux8;
1173
+ uint8_t m = 1;
1174
+ for (int j = 0; j < QK_K; j += 128) {
1175
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1176
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1177
+ a += 32; m <<= 1;
1178
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
1179
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1180
+ a += 32; m <<= 1;
1181
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
1182
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1183
+ a += 32; m <<= 1;
1184
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
1185
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1186
+ a += 32; m <<= 1;
1187
+ q3 += 32;
1188
+ }
1189
+ a = aux8;
1190
+
1191
+ memcpy(auxs, x[i].scales, 12);
1192
+ uint32_t tmp = auxs[2];
1193
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1194
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1195
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1196
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1197
+ for (int j = 0; j < QK_K/16; ++j) {
1198
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1199
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1200
+ q8 += 8; a += 8;
1201
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1202
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1203
+ q8 += 8; a += 8;
1204
+ }
1205
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1206
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1207
+ }
1208
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1209
+ *s = sumf;
1210
+
1211
+ #endif
1212
+
1213
+ }
1214
+
1215
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1216
+ assert(n % QK_K == 0);
1217
+ assert(nrc == 1);
1218
+ UNUSED(nrc);
1219
+ UNUSED(bx);
1220
+ UNUSED(by);
1221
+ UNUSED(bs);
1222
+
1223
+ const block_q4_K * GGML_RESTRICT x = vx;
1224
+ const block_q8_K * GGML_RESTRICT y = vy;
1225
+
1226
+ const int nb = n / QK_K;
1227
+
1228
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1229
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1230
+ static const uint32_t kmask3 = 0x03030303;
1231
+
1232
+ uint32_t utmp[4];
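+ // x[i].scales packs 8 6-bit scales and 8 6-bit mins into 12 bytes; the kmask constants above
+ // unpack them into utmp so that scales and mins can be read as plain bytes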
1233
+
1234
+ #if defined __riscv_xtheadvector
1235
+
1236
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1237
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1238
+
1239
+ float sumf = 0;
1240
+
1241
+ for (int i = 0; i < nb; ++i) {
1242
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1243
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1244
+
1245
+ int tmp, tmp2, sumi;
1246
+ __asm__ __volatile__(
1247
+ "li %[t1], 12\n\t"
1248
+ "th.vsetvli zero, %[t1], e8, m1\n\t"
1249
+ "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1250
+ "li %[t1], 4\n\t"
1251
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1252
+ "th.vslidedown.vi v2, v1, 2\n\t"
1253
+ "th.vmv.v.v v3, v2\n\t"
1254
+ "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1255
+ "li %[t1], 2\n\t"
1256
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1257
+ "th.vmv.v.i v4, 4\n\t"
1258
+ "th.vand.vx v8, v1, %[kmask1]\n\t"
1259
+ "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1260
+ "th.vsrl.vi v6, v1, 6\n\t"
1261
+ "th.vsrl.vv v7, v2, v5\n\t"
1262
+ "th.vand.vx v0, v6, %[kmask3]\n\t"
1263
+ "th.vand.vx v2, v7, %[kmask2]\n\t"
1264
+ "th.vsll.vi v6, v0, 4\n\t"
1265
+ "li %[t2], 8\n\t"
1266
+ "addi %[t1], %[utmp], 4\n\t"
1267
+ "th.vor.vv v1, v6, v2\n\t"
1268
+ "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
1269
+ "th.vssw.v v1, (%[t1]), %[t2]\n\t"
1270
+ "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
1271
+ "th.vlw.v v2, (%[bsums])\n\t"
1272
+ "th.vsetvli zero, %[t2], e16, m1\n\t"
1273
+ "th.vnsrl.vi v0, v2, 0\n\t"
1274
+ "th.vnsrl.vi v1, v2, 16\n\t"
1275
+ "th.vadd.vv v2, v0, v1\n\t"
1276
+ "th.vlbu.v v4, (%[mins])\n\t"
1277
+ "th.vwmul.vv v6, v4, v2\n\t"
1278
+ "th.vmv.v.x v0, zero\n\t"
1279
+ "th.vsetvli zero, %[t2], e32, m2\n\t"
1280
+ "th.vredsum.vs v0, v6, v0\n\t"
1281
+ "th.vmv.x.s %[sumi], v0"
1282
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1283
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1284
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1285
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1286
+ : "memory"
1287
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1288
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1289
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1290
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1291
+ );
1292
+ sumf -= dmin * sumi;
1293
+
1294
+ const uint8_t * restrict q4 = x[i].qs;
1295
+ const int8_t * restrict q8 = y[i].qs;
1296
+
1297
+ sumi = 0;
1298
+ const uint8_t * scale = scales;
1299
+
1300
+ for (int j = 0; j < QK_K/128; ++j) {
1301
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1302
+ __asm__ __volatile__(
1303
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
1304
+ "th.vlb.v v8, (%[q8])\n\t"
1305
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
1306
+ "th.vlb.v v0, (%[q4])\n\t"
1307
+ "th.vsrl.vi v4, v0, 4\n\t"
1308
+ "th.vand.vi v0, v0, 0xF\n\t"
1309
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
1310
+ "th.vwmul.vv v28, v6, v14\n\t"
1311
+ "th.vwmul.vv v20, v4, v10\n\t"
1312
+ "th.vwmul.vv v24, v2, v12\n\t"
1313
+ "th.vwmul.vv v16, v0, v8\n\t"
1314
+ "li %[tmp], 4\n\t"
1315
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1316
+ "th.vlbu.v v1, (%[scale])\n\t"
1317
+ "th.vmv.v.x v0, zero\n\t"
1318
+ "th.vsetvli zero, %[vl32], e16, m4\n\t"
1319
+ "th.vwredsum.vs v6, v24, v0\n\t"
1320
+ "th.vwredsum.vs v7, v28, v0\n\t"
1321
+ "th.vwredsum.vs v4, v16, v0\n\t"
1322
+ "th.vwredsum.vs v5, v20, v0\n\t"
1323
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1324
+ "th.vslideup.vi v6, v7, 1\n\t"
1325
+ "th.vslideup.vi v4, v5, 1\n\t"
1326
+ "th.vslideup.vi v4, v6, 2\n\t"
1327
+ "th.vmul.vv v8, v4, v1\n\t"
1328
+ "th.vredsum.vs v0, v8, v0\n\t"
1329
+ "th.vmv.x.s %[tmp], v0\n\t"
1330
+ "add %[sumi], %[sumi], %[tmp]"
1331
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1332
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1333
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1334
+ : "memory"
1335
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1336
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1337
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1338
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1339
+ );
1340
+
1341
+ q4 += 64; q8 += 128; scale += 4;
1342
+ }
1343
+
1344
+ sumf += d * sumi;
1345
+
1346
+ }
1347
+
1348
+ *s = sumf;
1349
+
1350
+ #elif defined __riscv_v
1351
+
1352
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1353
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1354
+
1355
+ float sumf = 0;
1356
+ const int vector_length = __riscv_vlenb() * 8;
1357
+
1358
+ switch (vector_length) {
1359
+ case 256:
1360
+ for (int i = 0; i < nb; ++i) {
1361
+
1362
+ size_t vl = 8;
1363
+
1364
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1365
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
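+ // the strided loads below read the even and odd bsums (4-byte stride over int16) and add them,
+ // pairing the 16 sub-block sums into 8 sums that line up with the 8 mins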
1366
+
1367
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
1368
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
1369
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
1370
+
1371
+ memcpy(utmp, x[i].scales, 12);
1372
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1373
+ const uint32_t uaux = utmp[1] & kmask1;
1374
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1375
+ utmp[2] = uaux;
1376
+ utmp[0] &= kmask1;
1377
+
1378
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
1379
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
1380
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
1381
+
1382
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1383
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1384
+
1385
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1386
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1387
+
1388
+ vl = 32;
1389
+
1390
+ int32_t sum_1 = 0;
1391
+ int32_t sum_2 = 0;
1392
+
1393
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
1394
+
1395
+ for (int j = 0; j < QK_K/64; ++j) {
1396
+ // load Q4
1397
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
1398
+
1399
+ // load Q8 and multiply it with lower Q4 nibble
1400
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
1401
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
1402
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
1403
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
1404
+
1405
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
1406
+
1407
+ // load Q8 and multiply it with upper Q4 nibble
1408
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
1409
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
1410
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
1411
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
1412
+
1413
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
1414
+
1415
+ q4 += 32; q8 += 64;
1416
+
1417
+ }
1418
+
1419
+ sumf += d*(sum_1 + sum_2);
1420
+
1421
+ }
1422
+ break;
1423
+ case 128:
1424
+ for (int i = 0; i < nb; ++i) {
1425
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1426
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1427
+
1428
+ int tmp, tmp2, sumi;
1429
+ __asm__ __volatile__(
1430
+ "vsetivli zero, 12, e8, m1\n\t"
1431
+ "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1432
+ "vsetivli zero, 4, e32, m1\n\t"
1433
+ "vslidedown.vi v2, v1, 2\n\t"
1434
+ "vmv1r.v v3, v2\n\t"
1435
+ "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1436
+ "vsetivli zero, 2, e32, m1\n\t"
1437
+ "vmv.v.i v4, 4\n\t"
1438
+ "vand.vx v8, v1, %[kmask1]\n\t"
1439
+ "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1440
+ "vsrl.vi v6, v1, 6\n\t"
1441
+ "vsrl.vv v7, v2, v5\n\t"
1442
+ "vand.vx v0, v6, %[kmask3]\n\t"
1443
+ "vand.vx v2, v7, %[kmask2]\n\t"
1444
+ "vsll.vi v6, v0, 4\n\t"
1445
+ "li %[t2], 8\n\t"
1446
+ "addi %[t1], %[utmp], 4\n\t"
1447
+ "vor.vv v1, v6, v2\n\t"
1448
+ "vsse32.v v8, (%[utmp]), %[t2]\n\t"
1449
+ "vsse32.v v1, (%[t1]), %[t2]\n\t"
1450
+ "vsetivli zero, 8, e16, m1\n\t"
1451
+ "vle32.v v2, (%[bsums])\n\t"
1452
+ "vnsrl.wi v0, v2, 0\n\t"
1453
+ "vnsrl.wi v1, v2, 16\n\t"
1454
+ "vadd.vv v2, v0, v1\n\t"
1455
+ "vle8.v v3, (%[mins])\n\t"
1456
+ "vzext.vf2 v4, v3\n\t"
1457
+ "vwmul.vv v6, v4, v2\n\t"
1458
+ "vmv.v.x v0, zero\n\t"
1459
+ "vsetivli zero, 8, e32, m2\n\t"
1460
+ "vredsum.vs v0, v6, v0\n\t"
1461
+ "vmv.x.s %[sumi], v0"
1462
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1463
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1464
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1465
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1466
+ : "memory"
1467
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1468
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1469
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1470
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1471
+ );
1472
+ sumf -= dmin * sumi;
1473
+
1474
+ const uint8_t * restrict q4 = x[i].qs;
1475
+ const int8_t * restrict q8 = y[i].qs;
1476
+
1477
+ sumi = 0;
1478
+ const uint8_t * scale = scales;
1479
+
1480
+ for (int j = 0; j < QK_K/128; ++j) {
1481
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1482
+ __asm__ __volatile__(
1483
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1484
+ "vle8.v v8, (%[q8])\n\t"
1485
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1486
+ "vle8.v v0, (%[q4])\n\t"
1487
+ "vsrl.vi v4, v0, 4\n\t"
1488
+ "vand.vi v0, v0, 0xF\n\t"
1489
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1490
+ "vwmul.vv v28, v6, v14\n\t"
1491
+ "vwmul.vv v20, v4, v10\n\t"
1492
+ "vwmul.vv v24, v2, v12\n\t"
1493
+ "vwmul.vv v16, v0, v8\n\t"
1494
+ "vsetivli zero, 4, e32, m1\n\t"
1495
+ "vle8.v v2, (%[scale])\n\t"
1496
+ "vmv.v.x v0, zero\n\t"
1497
+ "vzext.vf4 v1, v2\n\t"
1498
+ "vsetvli zero, %[vl32], e16, m4\n\t"
1499
+ "vwredsum.vs v6, v24, v0\n\t"
1500
+ "vwredsum.vs v7, v28, v0\n\t"
1501
+ "vwredsum.vs v4, v16, v0\n\t"
1502
+ "vwredsum.vs v5, v20, v0\n\t"
1503
+ "vsetivli zero, 4, e32, m1\n\t"
1504
+ "vslideup.vi v6, v7, 1\n\t"
1505
+ "vslideup.vi v4, v5, 1\n\t"
1506
+ "vslideup.vi v4, v6, 2\n\t"
1507
+ "vmul.vv v8, v4, v1\n\t"
1508
+ "vredsum.vs v0, v8, v0\n\t"
1509
+ "vmv.x.s %[tmp], v0\n\t"
1510
+ "add %[sumi], %[sumi], %[tmp]"
1511
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1512
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1513
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1514
+ : "memory"
1515
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1516
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1517
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1518
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1519
+ );
1520
+
1521
+ q4 += 64; q8 += 128; scale += 4;
1522
+ }
1523
+
1524
+ sumf += d * sumi;
1525
+ }
1526
+ break;
1527
+ default:
1528
+ assert(false && "Unsupported vector length");
1529
+ break;
1530
+ }
1531
+
1532
+ *s = sumf;
1533
+
1534
+ #else
1535
+
1536
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1537
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1538
+
1539
+ int8_t aux8[QK_K];
1540
+ int16_t aux16[8];
1541
+ float sums [8];
1542
+ int32_t aux32[8];
1543
+ memset(sums, 0, 8*sizeof(float));
1544
+
1545
+ float sumf = 0;
1546
+ for (int i = 0; i < nb; ++i) {
1547
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1548
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1549
+ memset(aux32, 0, 8*sizeof(int32_t));
1550
+ int8_t * GGML_RESTRICT a = aux8;
1551
+ for (int j = 0; j < QK_K/64; ++j) {
1552
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1553
+ a += 32;
1554
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1555
+ a += 32; q4 += 32;
1556
+ }
1557
+ memcpy(utmp, x[i].scales, 12);
1558
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1559
+ const uint32_t uaux = utmp[1] & kmask1;
1560
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1561
+ utmp[2] = uaux;
1562
+ utmp[0] &= kmask1;
1563
+
1564
+ int sumi = 0;
1565
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1566
+ a = aux8;
1567
+ int is = 0;
1568
+ for (int j = 0; j < QK_K/32; ++j) {
1569
+ int32_t scale = scales[is++];
1570
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1571
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1572
+ q8 += 8; a += 8;
1573
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1574
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1575
+ q8 += 8; a += 8;
1576
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1577
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1578
+ q8 += 8; a += 8;
1579
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1580
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1581
+ q8 += 8; a += 8;
1582
+ }
1583
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1584
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1585
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1586
+ sumf -= dmin * sumi;
1587
+ }
1588
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1589
+ *s = sumf;
1590
+ #endif
1591
+ }
1592
+
1593
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1594
+ assert(n % QK_K == 0);
1595
+ assert(nrc == 1);
1596
+ UNUSED(nrc);
1597
+ UNUSED(bx);
1598
+ UNUSED(by);
1599
+ UNUSED(bs);
1600
+
1601
+ const block_q5_K * GGML_RESTRICT x = vx;
1602
+ const block_q8_K * GGML_RESTRICT y = vy;
1603
+
1604
+ const int nb = n / QK_K;
1605
+
1606
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1607
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1608
+ static const uint32_t kmask3 = 0x03030303;
1609
+
1610
+ uint32_t utmp[4];
1611
+
1612
+ #if defined __riscv_v
1613
+
1614
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1615
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1616
+
1617
+ float sumf = 0;
1618
+ float sums = 0.0;
1619
+
1620
+ size_t vl;
1621
+
1622
+ for (int i = 0; i < nb; ++i) {
1623
+
1624
+ vl = 8;
1625
+
1626
+ const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1627
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1628
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1629
+
1630
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1631
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1632
+
1633
+ vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
1634
+ vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
1635
+ vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
1636
+
1637
+ memcpy(utmp, x[i].scales, 12);
1638
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1639
+ const uint32_t uaux = utmp[1] & kmask1;
1640
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1641
+ utmp[2] = uaux;
1642
+ utmp[0] &= kmask1;
1643
+
1644
+ vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
1645
+ vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
1646
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
1647
+
1648
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1649
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1650
+
1651
+ vl = 32;
1652
+ int32_t aux32 = 0;
1653
+ int is = 0;
1654
+
1655
+ uint8_t m = 1;
1656
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1657
+ vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
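+ // x[i].qh supplies the 5th bit of each quant; inside the loop the bit for the current
+ // sub-block (mask m) adds +16 to the nibble, matching the scalar (hm[l] & m ? 16 : 0) below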
1658
+
1659
+ for (int j = 0; j < QK_K/64; ++j) {
1660
+ // load Q5 and Q8
1661
+ vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
1662
+ vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
1663
+ vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
1664
+
1665
+ // compute mask for addition
1666
+ vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
1667
+ vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
1668
+ vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
1669
+ vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
1670
+ m <<= 1;
1671
+
1672
+ vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
1673
+ vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
1674
+ vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
1675
+ vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
1676
+ m <<= 1;
1677
+
1678
+ vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
1679
+ vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
1680
+
1681
+ vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
1682
+ vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
1683
+
1684
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
1685
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
1686
+
1687
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
1688
+ q5 += 32; q8 += 64;
1689
+
1690
+ }
1691
+
1692
+ sums += aux32 * d;
1693
+
1694
+ }
1695
+
1696
+ *s = sumf+sums;
1697
+
1698
+ #else
1699
+
1700
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1701
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1702
+
1703
+ int8_t aux8[QK_K];
1704
+ int16_t aux16[8];
1705
+ float sums [8];
1706
+ int32_t aux32[8];
1707
+ memset(sums, 0, 8*sizeof(float));
1708
+
1709
+ float sumf = 0;
1710
+ for (int i = 0; i < nb; ++i) {
1711
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1712
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1713
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1714
+ memset(aux32, 0, 8*sizeof(int32_t));
1715
+ int8_t * GGML_RESTRICT a = aux8;
1716
+ uint8_t m = 1;
1717
+ for (int j = 0; j < QK_K/64; ++j) {
1718
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1719
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1720
+ a += 32; m <<= 1;
1721
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1722
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1723
+ a += 32; m <<= 1;
1724
+ q4 += 32;
1725
+ }
1726
+ memcpy(utmp, x[i].scales, 12);
1727
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1728
+ const uint32_t uaux = utmp[1] & kmask1;
1729
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1730
+ utmp[2] = uaux;
1731
+ utmp[0] &= kmask1;
1732
+
1733
+ int sumi = 0;
1734
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1735
+ a = aux8;
1736
+ int is = 0;
1737
+ for (int j = 0; j < QK_K/32; ++j) {
1738
+ int32_t scale = scales[is++];
1739
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1740
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1741
+ q8 += 8; a += 8;
1742
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1743
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1744
+ q8 += 8; a += 8;
1745
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1746
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1747
+ q8 += 8; a += 8;
1748
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1749
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1750
+ q8 += 8; a += 8;
1751
+ }
1752
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1753
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1754
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1755
+ sumf -= dmin * sumi;
1756
+ }
1757
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1758
+ *s = sumf;
1759
+ #endif
1760
+ }
1761
+
1762
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1763
+ assert(n % QK_K == 0);
1764
+ assert(nrc == 1);
1765
+ UNUSED(nrc);
1766
+ UNUSED(bx);
1767
+ UNUSED(by);
1768
+ UNUSED(bs);
1769
+
1770
+ const block_q6_K * GGML_RESTRICT x = vx;
1771
+ const block_q8_K * GGML_RESTRICT y = vy;
1772
+
1773
+ const int nb = n / QK_K;
1774
+
1775
+ #if defined __riscv_xtheadvector
1776
+
1777
+ float sumf = 0;
1778
+
1779
+ for (int i = 0; i < nb; ++i) {
1780
+
1781
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1782
+
1783
+ const uint8_t * restrict q6 = x[i].ql;
1784
+ const uint8_t * restrict qh = x[i].qh;
1785
+ const int8_t * restrict q8 = y[i].qs;
1786
+
1787
+ const int8_t * restrict scale = x[i].scales;
1788
+
1789
+ int sum_t = 0;
1790
+ int t0;
1791
+
1792
+ for (int j = 0; j < QK_K/128; ++j) {
1793
+ __asm__ __volatile__(
1794
+ "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
1795
+ "th.vlb.v v4, (%[qh])\n\t"
1796
+ "th.vsll.vi v0, v4, 4\n\t"
1797
+ "th.vsll.vi v2, v4, 2\n\t"
1798
+ "th.vsrl.vi v6, v4, 2\n\t"
1799
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1800
+ "th.vlb.v v8, (%[q6])\n\t"
1801
+ "th.vsrl.vi v12, v8, 4\n\t"
1802
+ "th.vand.vi v8, v8, 0xF\n\t"
1803
+ "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
1804
+ "th.vand.vx v0, v0, %[mask]\n\t"
1805
+ "th.vor.vv v8, v8, v0\n\t"
1806
+ "th.vlb.v v0, (%[q8])\n\t"
1807
+ "th.vsub.vx v8, v8, %[vl32]\n\t"
1808
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1809
+ "th.vwmul.vv v16, v0, v8\n\t"
1810
+ "th.vwmul.vv v24, v4, v12\n\t"
1811
+ "li %[t0], 16\n\t"
1812
+ "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
1813
+ "th.vmv.v.x v0, zero\n\t"
1814
+ "th.vwredsum.vs v10, v16, v0\n\t"
1815
+ "th.vwredsum.vs v9, v18, v0\n\t"
1816
+ "th.vwredsum.vs v8, v20, v0\n\t"
1817
+ "th.vwredsum.vs v7, v22, v0\n\t"
1818
+ "th.vwredsum.vs v11, v24, v0\n\t"
1819
+ "th.vwredsum.vs v12, v26, v0\n\t"
1820
+ "th.vwredsum.vs v13, v28, v0\n\t"
1821
+ "th.vwredsum.vs v14, v30, v0\n\t"
1822
+ "li %[t0], 4\n\t"
1823
+ "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
1824
+ "th.vslideup.vi v10, v9, 1\n\t"
1825
+ "th.vslideup.vi v8, v7, 1\n\t"
1826
+ "th.vslideup.vi v11, v12, 1\n\t"
1827
+ "th.vslideup.vi v13, v14, 1\n\t"
1828
+ "th.vslideup.vi v10, v8, 2\n\t"
1829
+ "th.vslideup.vi v11, v13, 2\n\t"
1830
+ "li %[t0], 8\n\t"
1831
+ "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
1832
+ "th.vlb.v v4, (%[scale])\n\t"
1833
+ "th.vmul.vv v2, v4, v10\n\t"
1834
+ "th.vredsum.vs v0, v2, v0\n\t"
1835
+ "th.vmv.x.s %[t0], v0\n\t"
1836
+ "add %[sumi], %[sumi], %[t0]"
1837
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
1838
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
1839
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
1840
+ , [mask] "r" (0x30)
1841
+ : "memory"
1842
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1843
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1844
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1845
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1846
+ );
1847
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
1848
+ }
1849
+
1850
+ sumf += d * sum_t;
1851
+
1852
+ }
1853
+
1854
+ *s = sumf;
1855
+
1856
+ #elif defined __riscv_v
1857
+
1858
+ float sumf = 0;
1859
+ const int vector_length = __riscv_vlenb() * 8;
1860
+
1861
+ switch (vector_length) {
1862
+ case 256:
1863
+ for (int i = 0; i < nb; ++i) {
1864
+
1865
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1866
+
1867
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
1868
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1869
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1870
+
1871
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
1872
+
1873
+ size_t vl;
1874
+
1875
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1876
+
1877
+ int sum_t = 0;
1878
+ int is = 0;
1879
+
1880
+ for (int j = 0; j < QK_K/128; ++j) {
1881
+
1882
+ vl = 32;
1883
+
1884
+ // load qh
1885
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
1886
+
1887
+ // load Q6
1888
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
1889
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
1890
+
1891
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
1892
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
1893
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
1894
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
1895
+
1896
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
1897
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
1898
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
1899
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
1900
+
1901
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
1902
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
1903
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
1904
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
1905
+
1906
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
1907
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
1908
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
1909
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
1910
+
1911
+ // load Q8 and take product
1912
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
1913
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
1914
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
1915
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
1916
+
1917
+ vl = 16;
1918
+
1919
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
1920
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
1921
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
1922
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
1923
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
1924
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
1925
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
1926
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
1927
+
1928
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
1929
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
1930
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
1931
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
1932
+
1933
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1934
+
1935
+ q6 += 64; qh += 32; q8 += 128; is=8;
1936
+
1937
+ }
1938
+
1939
+ sumf += d * sum_t;
1940
+
1941
+ }
1942
+ break;
1943
+ case 128:
1944
+ for (int i = 0; i < nb; ++i) {
1945
+
1946
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1947
+
1948
+ const uint8_t * restrict q6 = x[i].ql;
1949
+ const uint8_t * restrict qh = x[i].qh;
1950
+ const int8_t * restrict q8 = y[i].qs;
1951
+
1952
+ const int8_t * restrict scale = x[i].scales;
1953
+
1954
+ int sum_t = 0;
1955
+ int t0;
1956
+
1957
+ for (int j = 0; j < QK_K/128; ++j) {
1958
+ __asm__ __volatile__(
1959
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1960
+ "vle8.v v4, (%[qh])\n\t"
1961
+ "vsll.vi v0, v4, 4\n\t"
1962
+ "vsll.vi v2, v4, 2\n\t"
1963
+ "vsrl.vi v6, v4, 2\n\t"
1964
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1965
+ "vle8.v v8, (%[q6])\n\t"
1966
+ "vsrl.vi v12, v8, 4\n\t"
1967
+ "vand.vi v8, v8, 0xF\n\t"
1968
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1969
+ "vand.vx v0, v0, %[mask]\n\t"
1970
+ "vor.vv v8, v8, v0\n\t"
1971
+ "vle8.v v0, (%[q8])\n\t"
1972
+ "vsub.vx v8, v8, %[vl32]\n\t"
1973
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1974
+ "vwmul.vv v16, v0, v8\n\t"
1975
+ "vwmul.vv v24, v4, v12\n\t"
1976
+ "vsetivli zero, 16, e16, m2\n\t"
1977
+ "vmv.v.x v0, zero\n\t"
1978
+ "vwredsum.vs v10, v16, v0\n\t"
1979
+ "vwredsum.vs v9, v18, v0\n\t"
1980
+ "vwredsum.vs v8, v20, v0\n\t"
1981
+ "vwredsum.vs v7, v22, v0\n\t"
1982
+ "vwredsum.vs v11, v24, v0\n\t"
1983
+ "vwredsum.vs v12, v26, v0\n\t"
1984
+ "vwredsum.vs v13, v28, v0\n\t"
1985
+ "vwredsum.vs v14, v30, v0\n\t"
1986
+ "vsetivli zero, 4, e32, m1\n\t"
1987
+ "vslideup.vi v10, v9, 1\n\t"
1988
+ "vslideup.vi v8, v7, 1\n\t"
1989
+ "vslideup.vi v11, v12, 1\n\t"
1990
+ "vslideup.vi v13, v14, 1\n\t"
1991
+ "vslideup.vi v10, v8, 2\n\t"
1992
+ "vslideup.vi v11, v13, 2\n\t"
1993
+ "vsetivli zero, 8, e32, m2\n\t"
1994
+ "vle8.v v2, (%[scale])\n\t"
1995
+ "vsext.vf4 v4, v2\n\t"
1996
+ "vmul.vv v2, v4, v10\n\t"
1997
+ "vredsum.vs v0, v2, v0\n\t"
1998
+ "vmv.x.s %[t0], v0\n\t"
1999
+ "add %[sumi], %[sumi], %[t0]"
2000
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
2001
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
2002
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
2003
+ , [mask] "r" (0x30)
2004
+ : "memory"
2005
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
2006
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
2007
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
2008
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
2009
+ );
2010
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
2011
+ }
2012
+
2013
+ sumf += d * sum_t;
2014
+
2015
+ }
2016
+ break;
2017
+ default:
2018
+ assert(false && "Unsupported vector length");
2019
+ break;
2020
+ }
2021
+
2022
+ *s = sumf;
2023
+
2024
+ #else
2025
+
2026
+ int8_t aux8[QK_K];
2027
+ int16_t aux16[8];
2028
+ float sums [8];
2029
+ int32_t aux32[8];
2030
+ memset(sums, 0, 8*sizeof(float));
2031
+
2032
+ float sumf = 0;
2033
+ for (int i = 0; i < nb; ++i) {
2034
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2035
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
2036
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
2037
+ memset(aux32, 0, 8*sizeof(int32_t));
2038
+ int8_t * GGML_RESTRICT a = aux8;
2039
+ for (int j = 0; j < QK_K; j += 128) {
2040
+ for (int l = 0; l < 32; ++l) {
2041
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2042
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2043
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2044
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2045
+ }
2046
+ a += 128;
2047
+ q4 += 64;
2048
+ qh += 32;
2049
+ }
2050
+ a = aux8;
2051
+ int is = 0;
2052
+ for (int j = 0; j < QK_K/16; ++j) {
2053
+ int scale = x[i].scales[is++];
2054
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2055
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2056
+ q8 += 8; a += 8;
2057
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2058
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2059
+ q8 += 8; a += 8;
2060
+ }
2061
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
2062
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2063
+ }
2064
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2065
+ *s = sumf;
2066
+ #endif
2067
+ }
2068
+
ggml/src/ggml-cpu/arch/riscv/repack.cpp ADDED
@@ -0,0 +1,396 @@
1
+ #define GGML_COMMON_IMPL_CPP
2
+ #define GGML_COMMON_DECL_CPP
3
+ #include "ggml-common.h"
4
+ #include "ggml-backend-impl.h"
5
+
6
+ #include "ggml-impl.h"
7
+ #include "ggml-cpu.h"
8
+ #include "ggml-cpu-impl.h"
9
+ #include "traits.h"
10
+
11
+ #include <cmath>
12
+ #include <cstring>
13
+ #include <cassert>
14
+ #include <cstdlib> // for qsort
15
+ #include <cstdio> // for GGML_ASSERT
16
+
17
+ #define GGML_CPU_CLANG_WORKAROUND
18
+ #include "../../repack.h"
19
+
20
+ #if defined(__GNUC__)
21
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
22
+ #endif
23
+
24
+ #define UNUSED GGML_UNUSED
25
+
26
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
27
+ const int qk = QK8_0;
28
+ const int nb = n / qk;
29
+ const int ncols_interleaved = 8;
30
+ const int blocklen = 8;
31
+
32
+ assert (n % qk == 0);
33
+ assert (nc % ncols_interleaved == 0);
34
+
35
+ UNUSED(s);
36
+ UNUSED(bs);
37
+ UNUSED(vx);
38
+ UNUSED(vy);
39
+ UNUSED(nr);
40
+ UNUSED(nc);
41
+ UNUSED(nb);
42
+ UNUSED(ncols_interleaved);
43
+ UNUSED(blocklen);
44
+
45
+ #if defined __riscv_v
46
+ if (__riscv_vlenb() >= QK4_0) {
47
+ const size_t vl = QK4_0;
48
+
49
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
50
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
51
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
52
+
53
+ vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
54
+ for (int l = 0; l < nb; l++) {
55
+ const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
56
+ const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
57
+ const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
58
+ const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
59
+ __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment constraints
60
+ const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
61
+ const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
62
+ const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
63
+ const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
64
+
65
+ const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
66
+ const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
67
+ const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
68
+ const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
69
+ const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
70
+ const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
71
+ const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
72
+
73
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
74
+ const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
75
+ const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
76
+ const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
77
+
78
+ const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
79
+ const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
80
+ const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
81
+ const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
82
+ const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
83
+ const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
84
+ const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
85
+ const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
86
+ const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
87
+ const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
88
+ const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
89
+ const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
90
+ const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
91
+
92
+ // vector version needs Zvfhmin extension
93
+ const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
94
+ const float b_scales[8] = {
95
+ GGML_FP16_TO_FP32(b_ptr[l].d[0]),
96
+ GGML_FP16_TO_FP32(b_ptr[l].d[1]),
97
+ GGML_FP16_TO_FP32(b_ptr[l].d[2]),
98
+ GGML_FP16_TO_FP32(b_ptr[l].d[3]),
99
+ GGML_FP16_TO_FP32(b_ptr[l].d[4]),
100
+ GGML_FP16_TO_FP32(b_ptr[l].d[5]),
101
+ GGML_FP16_TO_FP32(b_ptr[l].d[6]),
102
+ GGML_FP16_TO_FP32(b_ptr[l].d[7])
103
+ };
104
+ const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
105
+ const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
106
+ sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
107
+ }
108
+ __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
109
+ }
110
+ return;
111
+ }
112
+
113
+ #endif
114
+ {
115
+ float sumf[8];
116
+ int sumi;
117
+
118
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
119
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
120
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
121
+
122
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
123
+ for (int l = 0; l < nb; l++) {
124
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
125
+ for (int j = 0; j < ncols_interleaved; j++) {
126
+ sumi = 0;
127
+ for (int i = 0; i < blocklen; ++i) {
128
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
129
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
130
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
131
+ }
132
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
133
+ }
134
+ }
135
+ }
136
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
137
+ }
138
+ }
139
+ }
140
+
141
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
142
+ const int qk = QK8_0;
143
+ const int nb = n / qk;
144
+ const int ncols_interleaved = 8;
145
+ const int blocklen = 8;
146
+
147
+ assert (n % qk == 0);
148
+ assert (nr % 4 == 0);
149
+ assert (nc % ncols_interleaved == 0);
150
+
151
+ UNUSED(s);
152
+ UNUSED(bs);
153
+ UNUSED(vx);
154
+ UNUSED(vy);
155
+ UNUSED(nr);
156
+ UNUSED(nc);
157
+ UNUSED(nb);
158
+ UNUSED(ncols_interleaved);
159
+ UNUSED(blocklen);
160
+
161
+ #if defined __riscv_v
162
+ if (__riscv_vlenb() >= QK4_0) {
163
+ const size_t vl = QK4_0;
164
+
165
+ for (int y = 0; y < nr / 4; y++) {
166
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
167
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
168
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
169
+ vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
170
+ vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
171
+ vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
172
+ vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
173
+ for (int l = 0; l < nb; l++) {
174
+ const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
175
+ const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
176
+ const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
177
+ const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
178
+ const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
179
+ const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
180
+ const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
181
+
182
+ // vector version needs Zvfhmin extension
183
+ const float a_scales[4] = {
184
+ GGML_FP16_TO_FP32(a_ptr[l].d[0]),
185
+ GGML_FP16_TO_FP32(a_ptr[l].d[1]),
186
+ GGML_FP16_TO_FP32(a_ptr[l].d[2]),
187
+ GGML_FP16_TO_FP32(a_ptr[l].d[3])
188
+ };
189
+ const float b_scales[8] = {
190
+ GGML_FP16_TO_FP32(b_ptr[l].d[0]),
191
+ GGML_FP16_TO_FP32(b_ptr[l].d[1]),
192
+ GGML_FP16_TO_FP32(b_ptr[l].d[2]),
193
+ GGML_FP16_TO_FP32(b_ptr[l].d[3]),
194
+ GGML_FP16_TO_FP32(b_ptr[l].d[4]),
195
+ GGML_FP16_TO_FP32(b_ptr[l].d[5]),
196
+ GGML_FP16_TO_FP32(b_ptr[l].d[6]),
197
+ GGML_FP16_TO_FP32(b_ptr[l].d[7])
198
+ };
199
+ const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
200
+
201
+ const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
202
+ const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
203
+ const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
204
+ const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
205
+ __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
206
+ vint16m4_t sumi_l0;
207
+ {
208
+ const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
209
+ const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
210
+ const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
211
+ const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
212
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
213
+ const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
214
+ const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
215
+ const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
216
+
217
+ sumi_l0 = sumi_hi_m;
218
+ }
219
+
220
+ {
221
+ const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
222
+ const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
223
+ const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
224
+ const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
225
+ const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
226
+ const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
227
+ const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
228
+ const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
229
+ const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
230
+ const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
231
+ const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
232
+ const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
233
+ const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
234
+
235
+ const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
236
+ sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
237
+ }
238
+
239
+ const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
240
+ const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
241
+ const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
242
+ const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
243
+ __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
244
+ vint16m4_t sumi_l1;
245
+ {
246
+ const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
247
+ const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
248
+ const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
249
+ const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
250
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
251
+ const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
252
+ const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
253
+ const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
254
+
255
+ sumi_l1 = sumi_hi_m;
256
+ }
257
+
258
+ {
259
+ const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
260
+ const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
261
+ const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
262
+ const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
263
+ const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
264
+ const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
265
+ const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
266
+ const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
267
+ const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
268
+ const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
269
+ const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
270
+ const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
271
+ const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
272
+
273
+ const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
274
+ sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
275
+ }
276
+
277
+ const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
278
+ const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
279
+ const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
280
+ const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
281
+ __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
282
+ vint16m4_t sumi_l2;
283
+ {
284
+ const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
285
+ const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
286
+ const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
287
+ const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
288
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
289
+ const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
290
+ const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
291
+ const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
292
+
293
+ sumi_l2 = sumi_hi_m;
294
+ }
295
+
296
+ {
297
+ const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
298
+ const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
299
+ const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
300
+ const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
301
+ const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
302
+ const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
303
+ const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
304
+ const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
305
+ const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
306
+ const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
307
+ const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
308
+ const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
309
+ const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
310
+
311
+ const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
312
+ sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
313
+ }
314
+
315
+ const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
316
+ const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
317
+ const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
318
+ const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
319
+ __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
320
+ vint16m4_t sumi_l3;
321
+ {
322
+ const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
323
+ const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
324
+ const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
325
+ const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
326
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
327
+ const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
328
+ const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
329
+ const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
330
+
331
+ sumi_l3 = sumi_hi_m;
332
+ }
333
+
334
+ {
335
+ const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
336
+ const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
337
+ const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
338
+ const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
339
+ const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
340
+ const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
341
+ const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
342
+ const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
343
+ const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
344
+ const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
345
+ const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
346
+ const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
347
+ const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
348
+
349
+ const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
350
+ sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
351
+ }
352
+ }
353
+ __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
354
+ __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
355
+ __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
356
+ __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
357
+ }
358
+ }
359
+
360
+ return;
361
+ }
362
+
363
+ #endif // #if defined __riscv_v
364
+ float sumf[4][8];
365
+ int sumi;
366
+
367
+ for (int y = 0; y < nr / 4; y++) {
368
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
369
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
370
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
371
+ for (int m = 0; m < 4; m++) {
372
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
373
+ }
374
+ for (int l = 0; l < nb; l++) {
375
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
376
+ for (int m = 0; m < 4; m++) {
377
+ for (int j = 0; j < ncols_interleaved; j++) {
378
+ sumi = 0;
379
+ for (int i = 0; i < blocklen; ++i) {
380
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
381
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
382
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
383
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
384
+ }
385
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
386
+ }
387
+ }
388
+ }
389
+ }
390
+ for (int m = 0; m < 4; m++) {
391
+ for (int j = 0; j < ncols_interleaved; j++)
392
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
393
+ }
394
+ }
395
+ }
396
+ }
ggml/src/ggml-cpu/arch/s390/quants.c ADDED
@@ -0,0 +1,1299 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+
7
+ #include "../../quants.h"
8
+ #include "../../ggml-cpu-impl.h"
9
+
10
+ #include <math.h>
11
+ #include <string.h>
12
+ #include <assert.h>
13
+ #include <float.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
26
+ assert(QK8_0 == 32);
27
+ assert(k % QK8_0 == 0);
28
+ const int nb = k / QK8_0;
29
+
30
+ block_q8_0 * GGML_RESTRICT y = vy;
31
+
32
+ #if defined(__VXE__) || defined(__VXE2__)
33
+ for (int i = 0; i < nb; i++) {
34
+ __vector float srcv [8];
35
+ __vector float asrcv[8];
36
+ __vector float amaxv[8];
37
+
38
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
39
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
40
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
41
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
42
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
43
+
44
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
45
+ vec_extract(amaxv[0], 1)),
46
+ MAX(vec_extract(amaxv[0], 2),
47
+ vec_extract(amaxv[0], 3)));
48
+
49
+ const float d = amax / ((1 << 7) - 1);
50
+ const float id = d ? 1.0f / d : 0.0f;
51
+
52
+ y[i].d = GGML_FP32_TO_FP16(d);
53
+
54
+ for (int j = 0; j < 8; j++) {
55
+ const __vector float v = vec_mul(srcv[j], vec_splats(id));
56
+ const __vector int32_t vi = vec_signed(v);
57
+
58
+ y[i].qs[4*j + 0] = vec_extract(vi, 0);
59
+ y[i].qs[4*j + 1] = vec_extract(vi, 1);
60
+ y[i].qs[4*j + 2] = vec_extract(vi, 2);
61
+ y[i].qs[4*j + 3] = vec_extract(vi, 3);
62
+ }
63
+ }
64
+ #else
65
+ GGML_UNUSED(nb);
66
+ // scalar
67
+ quantize_row_q8_0_ref(x, y, k);
68
+ #endif
69
+ }
70
+
71
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
72
+ assert(k % QK8_1 == 0);
73
+ const int nb = k / QK8_1;
74
+
75
+ block_q8_1 * GGML_RESTRICT y = vy;
76
+
77
+ #if defined(__VXE__) || defined(__VXE2__)
78
+ for (int i = 0; i < nb; i++) {
79
+ __vector float srcv [8];
80
+ __vector float asrcv[8];
81
+ __vector float amaxv[8];
82
+
83
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
84
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
85
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
86
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
87
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
88
+
89
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
90
+ vec_extract(amaxv[0], 1)),
91
+ MAX(vec_extract(amaxv[0], 2),
92
+ vec_extract(amaxv[0], 3)));
93
+
94
+ const float d = amax / ((1 << 7) - 1);
95
+ const float id = d ? 1.0f / d : 0.0f;
96
+
97
+ y[i].d = GGML_FP32_TO_FP16(d);
98
+
99
+ __vector int32_t acc = vec_splats(0);
100
+
101
+ for (int j = 0; j < 8; j++) {
102
+ const __vector float v = vec_mul(srcv[j], vec_splats(id));
103
+ const __vector int32_t vi = vec_signed(v);
104
+
105
+ y[i].qs[4*j + 0] = vec_extract(vi, 0);
106
+ y[i].qs[4*j + 1] = vec_extract(vi, 1);
107
+ y[i].qs[4*j + 2] = vec_extract(vi, 2);
108
+ y[i].qs[4*j + 3] = vec_extract(vi, 3);
109
+
110
+ acc = vec_add(acc, vi);
111
+ }
112
+
113
+ y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
114
+ }
115
+ #else
116
+ GGML_UNUSED(nb);
117
+ // scalar
118
+ quantize_row_q8_1_ref(x, y, k);
119
+ #endif
120
+ }
121
+
122
+
123
+ //===================================== Dot products =================================
124
+
125
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
126
+ const int qk = QK8_0;
127
+ const int nb = n / qk;
128
+
129
+ assert(n % qk == 0);
130
+ assert(nrc == 1);
131
+ UNUSED(nrc);
132
+ UNUSED(bx);
133
+ UNUSED(by);
134
+ UNUSED(bs);
135
+
136
+ const block_q4_0 * GGML_RESTRICT x = vx;
137
+ const block_q8_0 * GGML_RESTRICT y = vy;
138
+
139
+ int ib = 0;
140
+ float sumf = 0;
141
+
142
+ #if defined(__VXE__) || defined(__VXE2__)
143
+ __vector float acc = vec_splats(0.0f);
144
+
145
+ const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
146
+ const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
147
+
148
+ for (; ib < nb; ++ib) {
149
+ const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
150
+ const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
151
+ const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
152
+
153
+ const __vector int8_t v_xls = vec_sub(v_xl, v_s);
154
+ const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
155
+
156
+ const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
157
+ const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
158
+
159
+ const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
160
+ const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
161
+ const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
162
+ const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
163
+
164
+ __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
165
+
166
+ const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
167
+ const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
168
+
169
+ acc = vec_madd(v_xy, v_d, acc);
170
+ }
171
+
172
+ sumf = acc[0] + acc[1] + acc[2] + acc[3];
173
+
174
+ #endif
175
+ for (; ib < nb; ++ib) {
176
+ int sumi0 = 0;
177
+ int sumi1 = 0;
178
+
179
+ for (int j = 0; j < qk/2; ++j) {
180
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
181
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
182
+
183
+ sumi0 += (v0 * y[ib].qs[j]);
184
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
185
+ }
186
+
187
+ int sumi = sumi0 + sumi1;
188
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
189
+ }
190
+
191
+ *s = sumf;
192
+ }
193
+
194
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
195
+ const int qk = QK8_1;
196
+ const int nb = n / qk;
197
+
198
+ assert(n % qk == 0);
199
+ assert(nrc == 1);
200
+ UNUSED(nrc);
201
+ UNUSED(bx);
202
+ UNUSED(by);
203
+ UNUSED(bs);
204
+
205
+ const block_q4_1 * GGML_RESTRICT x = vx;
206
+ const block_q8_1 * GGML_RESTRICT y = vy;
207
+
208
+ int ib = 0;
209
+ float sumf = 0;
210
+
211
+ #if defined(__VXE__) || defined(__VXE2__)
212
+ float summs = 0;
213
+ float32x4_t acc = vec_splats(0.0f);
214
+
215
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
216
+
217
+ #pragma GCC unroll 4
218
+ for (; ib < nb; ++ib) {
219
+ __builtin_prefetch(x[ib].qs, 0, 1);
220
+ __builtin_prefetch(y[ib].qs, 0, 1);
221
+
222
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
223
+
224
+ const uint8x16_t v_x = vec_xl(0, x[ib].qs);
225
+ const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
226
+ const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
227
+
228
+ const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
229
+ const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
230
+
231
+ const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
232
+ const float32x4_t v_xy = vec_float(v_xy_);
233
+
234
+ const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
235
+
236
+ acc = vec_madd(v_xy, v_d, acc);
237
+ }
238
+
239
+ sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
240
+
241
+ #endif
242
+ for (; ib < nb; ++ib) {
243
+ int sumi0 = 0;
244
+ int sumi1 = 0;
245
+
246
+ for (int j = 0; j < qk/2; ++j) {
247
+ const int v0 = (x[ib].qs[j] & 0x0F);
248
+ const int v1 = (x[ib].qs[j] >> 4);
249
+
250
+ sumi0 += (v0 * y[ib].qs[j]);
251
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
252
+ }
253
+
254
+ int sumi = sumi0 + sumi1;
255
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
256
+ }
257
+
258
+ *s = sumf;
259
+ }
260
+
261
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
262
+ const int qk = QK8_0;
263
+ const int nb = n / qk;
264
+
265
+ assert(n % qk == 0);
266
+ assert(nrc == 1);
267
+ UNUSED(nrc);
268
+ UNUSED(bx);
269
+ UNUSED(by);
270
+ UNUSED(bs);
271
+
272
+ const block_q8_0 * GGML_RESTRICT x = vx;
273
+ const block_q8_0 * GGML_RESTRICT y = vy;
274
+
275
+ int ib = 0;
276
+ float sumf = 0;
277
+
278
+ #if defined(__VXE__) || defined(__VXE2__)
279
+ __vector float acc = vec_splats(0.0f);
280
+
281
+ #pragma GCC unroll 8
282
+ for (; ib < nb; ++ib) {
283
+ __builtin_prefetch(x[ib].qs, 0, 1);
284
+ __builtin_prefetch(y[ib].qs, 0, 1);
285
+
286
+ const int8x16_t v_xl = vec_xl(0 , x[ib].qs);
287
+ const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
288
+ const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
289
+ const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
290
+
291
+ const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
292
+ const float32x4_t v_xy = vec_float(v_xy_);
293
+ const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
294
+
295
+ acc = vec_madd(v_xy, v_d, acc);
296
+ }
297
+
298
+ sumf = acc[0] + acc[1] + acc[2] + acc[3];
299
+
300
+ #endif
301
+ for (; ib < nb; ++ib) {
302
+ int sumi = 0;
303
+
304
+ for (int j = 0; j < qk; j++) {
305
+ sumi += x[ib].qs[j]*y[ib].qs[j];
306
+ }
307
+
308
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
309
+ }
310
+
311
+ *s = sumf;
312
+ }
313
+
314
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
315
+ assert(n % QK_K == 0);
316
+ assert(nrc == 1);
317
+ UNUSED(nrc);
318
+ UNUSED(bx);
319
+ UNUSED(by);
320
+ UNUSED(bs);
321
+
322
+ const uint32_t kmask1 = 0x03030303;
323
+ const uint32_t kmask2 = 0x0f0f0f0f;
324
+
325
+ const block_q3_K * GGML_RESTRICT x = vx;
326
+ const block_q8_K * GGML_RESTRICT y = vy;
327
+
328
+ const int nb = n / QK_K;
329
+
330
+ #if defined(__VXE__) || defined(__VXE2__)
331
+ uint32_t aux[3];
332
+ uint32_t utmp[4];
333
+
334
+ const int32x4_t v_z = vec_splat_s32(0);
335
+ const uint8x16_t v_3m = vec_splat_u8(0x03);
336
+
337
+ const uint8x16_t v_0c = vec_splat_u8(1);
338
+ const uint8x16_t v_1c = vec_sl(v_0c, 1);
339
+ const uint8x16_t v_2c = vec_sl(v_0c, 2);
340
+ const uint8x16_t v_3c = vec_sl(v_0c, 3);
341
+
342
+ uint8x16_t q3h[4];
343
+ uint8x16_t q3b[2];
344
+ int8x16_t q3bytes[4];
345
+ int8x16_t q8bytes[8];
346
+ uint8x16_t qhbits[2];
347
+
348
+ float sum = 0;
349
+
350
+ for (int i = 0; i < nb; ++i) {
351
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
352
+
353
+ const uint8_t * restrict x0l = x[i].qs;
354
+ const uint8_t * restrict x0h = x[i].hmask;
355
+ const int8_t * restrict y0 = y[i].qs;
356
+
357
+ qhbits[0] = vec_xl(0 , x0h);
358
+ qhbits[1] = vec_xl(16, x0h);
359
+
360
+ int32_t isum = 0;
361
+
362
+ memcpy(aux, x[i].scales, 12);
363
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
364
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
365
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
366
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
367
+
368
+ int8_t * scale = (int8_t *)utmp;
369
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
370
+
371
+ for (int j = 0; j < QK_K/128; ++j) {
372
+ int32x4_t isum0, isum1, isum2, isum3;
373
+
374
+ q3b[0] = vec_xl(0 , x0l);
375
+ q3b[1] = vec_xl(16, x0l);
376
+ x0l += 32;
377
+
378
+ q8bytes[0] = vec_xl(0 , y0);
379
+ q8bytes[1] = vec_xl(16 , y0);
380
+ q8bytes[2] = vec_xl(32 , y0);
381
+ q8bytes[3] = vec_xl(48 , y0);
382
+ q8bytes[4] = vec_xl(64 , y0);
383
+ q8bytes[5] = vec_xl(80 , y0);
384
+ q8bytes[6] = vec_xl(96 , y0);
385
+ q8bytes[7] = vec_xl(112, y0);
386
+ y0 += 128;
387
+
388
+ q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
389
+ q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
390
+ q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
391
+ q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
392
+
393
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
394
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
395
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
396
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
397
+
398
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
399
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
400
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
401
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
402
+
403
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
404
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
405
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
406
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
407
+
408
+ scale += 4;
409
+
410
+ q3h[0] = vec_andc(v_2c, qhbits[0]);
411
+ q3h[1] = vec_andc(v_2c, qhbits[1]);
412
+ q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
413
+ q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
414
+
415
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
416
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
417
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
418
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
419
+
420
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
421
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
422
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
423
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
424
+
425
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
426
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
427
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
428
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
429
+
430
+ scale += 4;
431
+
432
+ if (j == 0) {
433
+ qhbits[0] = vec_sr(qhbits[0], 4);
434
+ qhbits[1] = vec_sr(qhbits[1], 4);
435
+ }
436
+ }
437
+
438
+ sum += d * isum;
439
+ }
440
+
441
+ *s = sum;
442
+
443
+ #else
444
+ // scalar version
445
+ // This function is written like this so the compiler can manage to vectorize most of it
446
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
447
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
448
+ // The ideal situation would be if we could just write the code once, and the compiler would
449
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
450
+ // write vectorized versions for AVX, ARM_NEON, etc.
451
+
452
+ int8_t aux8[QK_K];
453
+ int16_t aux16[8];
454
+ float sums [8];
455
+ int32_t aux32[8];
456
+ memset(sums, 0, 8*sizeof(float));
457
+
458
+ uint32_t auxs[4];
459
+ const int8_t * scales = (const int8_t*)auxs;
460
+
461
+ float sumf = 0;
462
+ for (int i = 0; i < nb; ++i) {
463
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
464
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
465
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
466
+ memset(aux32, 0, 8*sizeof(int32_t));
467
+ int8_t * GGML_RESTRICT a = aux8;
468
+ uint8_t m = 1;
469
+ for (int j = 0; j < QK_K; j += 128) {
470
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
471
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
472
+ a += 32; m <<= 1;
473
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
474
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
475
+ a += 32; m <<= 1;
476
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
477
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
478
+ a += 32; m <<= 1;
479
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
480
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
481
+ a += 32; m <<= 1;
482
+ q3 += 32;
483
+ }
484
+ a = aux8;
485
+
486
+ memcpy(auxs, x[i].scales, 12);
487
+ uint32_t tmp = auxs[2];
488
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
489
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
490
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
491
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
492
+ for (int j = 0; j < QK_K/16; ++j) {
493
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
494
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
495
+ q8 += 8; a += 8;
496
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
497
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
498
+ q8 += 8; a += 8;
499
+ }
500
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
501
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
502
+ }
503
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
504
+ *s = sumf;
505
+
506
+ #endif
507
+
508
+ }
509
+
510
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
511
+ assert(n % QK_K == 0);
512
+ assert(nrc == 1);
513
+ UNUSED(nrc);
514
+ UNUSED(bx);
515
+ UNUSED(by);
516
+ UNUSED(bs);
517
+
518
+ const block_q4_K * GGML_RESTRICT x = vx;
519
+ const block_q8_K * GGML_RESTRICT y = vy;
520
+
521
+ const int nb = n / QK_K;
522
+
523
+ static const uint32_t kmask1 = 0x3f3f3f3f;
524
+ static const uint32_t kmask2 = 0x0f0f0f0f;
525
+ static const uint32_t kmask3 = 0x03030303;
526
+
527
+ uint32_t utmp[4];
528
+
529
+ #if defined(__VXE__) || defined(__VXE2__)
530
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
531
+ const int32x4_t v_z = vec_splat_s32(0);
532
+
533
+ uint8x16_t v_x[2];
534
+ int8x16_t v_xl[2];
535
+ int8x16_t v_y[2];
536
+
537
+ float sumf = 0;
538
+
539
+ for (int i = 0; i < nb; ++i) {
540
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
541
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
542
+
543
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
544
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
545
+ const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
546
+
547
+ memcpy(utmp, x[i].scales, 12);
548
+
549
+ uint32x4_t v_mins8 = { 0 };
550
+ v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
551
+ v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
552
+
553
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
554
+ utmp[0] &= kmask1;
555
+
556
+ const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
557
+
558
+ const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
559
+ const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
560
+ const int32x4_t v_mins = v_minso + v_minse;
561
+ sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
562
+
563
+ const uint8_t * scales = (const uint8_t *)utmp;
564
+ const uint8_t * GGML_RESTRICT x0 = x[i].qs;
565
+ const int8_t * GGML_RESTRICT y0 = y[i].qs;
566
+
567
+ int32_t sumi1 = 0;
568
+ int32_t sumi2 = 0;
569
+
570
+ for (int j = 0; j < QK_K/64; ++j) {
571
+ v_x[0] = vec_xl(0 , x0);
572
+ v_x[1] = vec_xl(16, x0);
573
+ x0 += 32;
574
+
575
+ v_y[0] = vec_xl(0 , y0);
576
+ v_y[1] = vec_xl(16, y0);
577
+ y0 += 32;
578
+
579
+ v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
580
+ v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
581
+
582
+ const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
583
+ sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
584
+
585
+ v_y[0] = vec_xl(0 , y0);
586
+ v_y[1] = vec_xl(16, y0);
587
+ y0 += 32;
588
+
589
+ v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
590
+ v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
591
+
592
+ const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
593
+ sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
594
+ }
595
+
596
+ sumf += d * (sumi1 + sumi2);
597
+ }
598
+
599
+ *s = sumf;
600
+
601
+ #else
602
+
603
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
604
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
605
+
606
+ int8_t aux8[QK_K];
607
+ int16_t aux16[8];
608
+ float sums [8];
609
+ int32_t aux32[8];
610
+ memset(sums, 0, 8*sizeof(float));
611
+
612
+ float sumf = 0;
613
+ for (int i = 0; i < nb; ++i) {
614
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
615
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
616
+ memset(aux32, 0, 8*sizeof(int32_t));
617
+ int8_t * GGML_RESTRICT a = aux8;
618
+ for (int j = 0; j < QK_K/64; ++j) {
619
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
620
+ a += 32;
621
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
622
+ a += 32; q4 += 32;
623
+ }
624
+ memcpy(utmp, x[i].scales, 12);
625
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
626
+ const uint32_t uaux = utmp[1] & kmask1;
627
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
628
+ utmp[2] = uaux;
629
+ utmp[0] &= kmask1;
630
+
631
+ int sumi = 0;
632
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
633
+ a = aux8;
634
+ int is = 0;
635
+ for (int j = 0; j < QK_K/32; ++j) {
636
+ int32_t scale = scales[is++];
637
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
638
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
639
+ q8 += 8; a += 8;
640
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
641
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
642
+ q8 += 8; a += 8;
643
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
644
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
645
+ q8 += 8; a += 8;
646
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
647
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
648
+ q8 += 8; a += 8;
649
+ }
650
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
651
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
652
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
653
+ sumf -= dmin * sumi;
654
+ }
655
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
656
+ *s = sumf;
657
+ #endif
658
+ }
659
+
660
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
661
+ assert(n % QK_K == 0);
662
+ assert(nrc == 1);
663
+ UNUSED(nrc);
664
+ UNUSED(bx);
665
+ UNUSED(by);
666
+ UNUSED(bs);
667
+
668
+ const block_q5_K * GGML_RESTRICT x = vx;
669
+ const block_q8_K * GGML_RESTRICT y = vy;
670
+
671
+ const int nb = n / QK_K;
672
+
673
+ static const uint32_t kmask1 = 0x3f3f3f3f;
674
+ static const uint32_t kmask2 = 0x0f0f0f0f;
675
+ static const uint32_t kmask3 = 0x03030303;
676
+
677
+ uint32_t utmp[4];
678
+
679
+ #if defined(__VXE__) || defined(__VXE2__)
680
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
681
+ const uint8x16_t v_1m = vec_splat_u8(0x01);
682
+ const uint8x16_t v_2m = vec_splat_u8(0x02);
683
+
684
+ const int32x4_t v_z = vec_splat_s32(0);
685
+
686
+ const uchar8x16_t v_minsm = {
687
+ 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
688
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
689
+ };
690
+
691
+ int8x16_t q5b[4];
692
+ uint8x16_t q5h[4];
693
+
694
+ uint8x16_t v_xl[2];
695
+ uint8x16_t v_xh[2];
696
+ int8x16_t v_y[4];
697
+
698
+ float sumf = 0;
699
+
700
+ for (int i = 0; i < nb; ++i) {
701
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
702
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
703
+
704
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
705
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
706
+ const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
707
+
708
+ memcpy(utmp, x[i].scales, 12);
709
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
710
+ const uint32_t uaux = utmp[1] & kmask1;
711
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
712
+ utmp[2] = uaux;
713
+ utmp[0] &= kmask1;
714
+
715
+ const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
716
+ const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
717
+ const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
718
+
719
+ const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
720
+ const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
721
+ const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
722
+ const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
723
+
724
+ const uint8_t * scales = (const uint8_t *)utmp;
725
+ const uint8_t * GGML_RESTRICT x0l = x[i].qs;
726
+ const uint8_t * GGML_RESTRICT x0h = x[i].qh;
727
+ const int8_t * GGML_RESTRICT y0 = y[i].qs;
728
+
729
+ v_xh[0] = vec_xl(0 , x0h);
730
+ v_xh[1] = vec_xl(16, x0h);
731
+
732
+ int32_t sumi = 0;
733
+ for (int j = 0; j < QK_K/64; ++j) {
734
+ v_xl[0] = vec_xl(0 , x0l);
735
+ v_xl[1] = vec_xl(16, x0l);
736
+ x0l += 32;
737
+
738
+ v_y[0] = vec_xl(0 , y0);
739
+ v_y[1] = vec_xl(16, y0);
740
+ v_y[2] = vec_xl(32, y0);
741
+ v_y[3] = vec_xl(48, y0);
742
+ y0 += 64;
743
+
744
+ q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
745
+ q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
746
+ q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
747
+ q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
748
+ v_xh[0] = vec_sr(v_xh[0], 2);
749
+ v_xh[1] = vec_sr(v_xh[1], 2);
750
+
751
+ q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
752
+ q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
753
+ q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
754
+ q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
755
+
756
+ int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
757
+ int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
758
+
759
+ sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
760
+ sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
761
+ }
762
+
763
+ sumf += d * sumi - dmin * mins;
764
+ }
765
+
766
+ *s = sumf;
767
+
768
+ #else
769
+
770
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
771
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
772
+
773
+ int8_t aux8[QK_K];
774
+ int16_t aux16[8];
775
+ float sums [8];
776
+ int32_t aux32[8];
777
+ memset(sums, 0, 8*sizeof(float));
778
+
779
+ float sumf = 0;
780
+ for (int i = 0; i < nb; ++i) {
781
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
782
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
783
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
784
+ memset(aux32, 0, 8*sizeof(int32_t));
785
+ int8_t * GGML_RESTRICT a = aux8;
786
+ uint8_t m = 1;
787
+ for (int j = 0; j < QK_K/64; ++j) {
788
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
789
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
790
+ a += 32; m <<= 1;
791
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
792
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
793
+ a += 32; m <<= 1;
794
+ q4 += 32;
795
+ }
796
+ memcpy(utmp, x[i].scales, 12);
797
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
798
+ const uint32_t uaux = utmp[1] & kmask1;
799
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
800
+ utmp[2] = uaux;
801
+ utmp[0] &= kmask1;
802
+
803
+ int sumi = 0;
804
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
805
+ a = aux8;
806
+ int is = 0;
807
+ for (int j = 0; j < QK_K/32; ++j) {
808
+ int32_t scale = scales[is++];
809
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
810
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
811
+ q8 += 8; a += 8;
812
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
813
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
814
+ q8 += 8; a += 8;
815
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
816
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
817
+ q8 += 8; a += 8;
818
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
819
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
820
+ q8 += 8; a += 8;
821
+ }
822
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
823
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
824
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
825
+ sumf -= dmin * sumi;
826
+ }
827
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
828
+ *s = sumf;
829
+ #endif
830
+ }
831
+
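Note on the scale handling in the two functions above: the kmask1/kmask2/kmask3 word operations in ggml_vec_dot_q4_K_q8_K and ggml_vec_dot_q5_K_q8_K unpack the 12-byte scales field of a K-quant super-block into eight 6-bit scales followed by eight 6-bit mins. A minimal standalone sketch of the same shuffle, under the layout implied by the code above (the helper name unpack_k4_scales is illustrative and not part of ggml):

    #include <stdint.h>
    #include <string.h>

    // Sketch: rebuild the 8 six-bit scales and 8 six-bit mins from the 12 packed
    // bytes of a q4_K/q5_K super-block, mirroring the kmask word operations above.
    static void unpack_k4_scales(const uint8_t packed[12], uint8_t scales[8], uint8_t mins[8]) {
        const uint32_t kmask1 = 0x3f3f3f3f;
        const uint32_t kmask2 = 0x0f0f0f0f;
        const uint32_t kmask3 = 0x03030303;

        uint32_t utmp[4];
        memcpy(utmp, packed, 12);

        // mins[4..7]: high nibble of bytes 8..11, plus the top 2 bits of bytes 4..7 moved into bits 4..5
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        // mins[0..3]: low 6 bits of bytes 4..7
        const uint32_t uaux = utmp[1] & kmask1;
        // scales[4..7]: low nibble of bytes 8..11, plus the top 2 bits of bytes 0..3 moved into bits 4..5
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        // scales[0..3]: low 6 bits of bytes 0..3
        utmp[0] &= kmask1;

        memcpy(scales, (const uint8_t *)utmp,     8);  // utmp[0..1]
        memcpy(mins,   (const uint8_t *)utmp + 8, 8);  // utmp[2..3]
    }

The word-level shifts never leak across byte boundaries because each result is masked with kmask2 or kmask3 immediately afterwards, so the transform amounts to eight independent byte extractions despite being written as 32-bit arithmetic.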
832
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
833
+ assert(n % QK_K == 0);
834
+ assert(nrc == 1);
835
+ UNUSED(nrc);
836
+ UNUSED(bx);
837
+ UNUSED(by);
838
+ UNUSED(bs);
839
+
840
+ const block_q6_K * GGML_RESTRICT x = vx;
841
+ const block_q8_K * GGML_RESTRICT y = vy;
842
+
843
+ const int nb = n / QK_K;
844
+
845
+ #if defined(__VXE__) || defined(__VXE2__)
846
+ float sum = 0;
847
+
848
+ // Lower 4-bit and upper 2-bit masks
849
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
850
+ const uint8x16_t v_um = vec_splat_u8(0x03);
851
+
852
+ const int32x4_t v_z = vec_splat_s32(0);
853
+
854
+ int8x16_t q6b[4];
855
+ uint8x16_t q6h[4];
856
+
857
+ uint8x16_t v_xl[4];
858
+ uint8x16_t v_xh[2];
859
+ int8x16_t v_y[4];
860
+
861
+ for (int i = 0; i < nb; ++i) {
862
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
863
+
864
+ const uint8_t * GGML_RESTRICT x0l = x[i].ql;
865
+ const uint8_t * GGML_RESTRICT x0h = x[i].qh;
866
+ const int8_t * GGML_RESTRICT y0 = y[i].qs;
867
+
868
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
869
+
870
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
871
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
872
+
873
+ const int8x16_t v_scale = vec_xl(0, scale);
874
+ const int16x8_t v_scalel = vec_unpackh(v_scale);
875
+ const int16x8_t v_scaleh = vec_unpackl(v_scale);
876
+
877
+ const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
878
+ const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
879
+ const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
880
+ const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
881
+ const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
882
+
883
+ const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
884
+
885
+ int32_t isum = 0;
886
+ for (int j = 0; j < QK_K/128; ++j) {
887
+ // Load model upper 2 bits
888
+ v_xh[0] = vec_xl(0 , x0h);
889
+ v_xh[1] = vec_xl(16, x0h);
890
+ x0h += 32;
891
+
892
+ // Load model lower 4 bits
893
+ v_xl[0] = vec_xl(0 , x0l);
894
+ v_xl[1] = vec_xl(16, x0l);
895
+ v_xl[2] = vec_xl(32, x0l);
896
+ v_xl[3] = vec_xl(48, x0l);
897
+ x0l += 64;
898
+
899
+ // Load activation quants
900
+ v_y[0] = vec_xl(0 , y0);
901
+ v_y[1] = vec_xl(16, y0);
902
+ v_y[2] = vec_xl(32, y0);
903
+ v_y[3] = vec_xl(48, y0);
904
+ y0 += 64;
905
+
906
+ q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
907
+ q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
908
+ uint8x16_t shifted = vec_sr(v_xh[0], 2);
909
+ q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
910
+ shifted = vec_sr(v_xh[1], 2);
911
+ q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
912
+
913
+ q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
914
+ q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
915
+ q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
916
+ q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
917
+
918
+ int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
919
+ int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
920
+ int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
921
+ int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
922
+
923
+ isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
924
+ (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
925
+ (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
926
+ (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
927
+
928
+ scale += 4;
929
+
930
+
931
+ // Load activation quants
932
+ v_y[0] = vec_xl(0 , y0);
933
+ v_y[1] = vec_xl(16, y0);
934
+ v_y[2] = vec_xl(32, y0);
935
+ v_y[3] = vec_xl(48, y0);
936
+ y0 += 64;
937
+
938
+ shifted = vec_sr(v_xh[0], 4);
939
+ q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
940
+ shifted = vec_sr(v_xh[1], 4);
941
+ q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
942
+ shifted = vec_sr(v_xh[0], 6);
943
+ q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
944
+ shifted = vec_sr(v_xh[1], 6);
945
+ q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
946
+
947
+ q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
948
+ q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
949
+ q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
950
+ q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
951
+
952
+ summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
953
+ summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
954
+ summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
955
+ summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
956
+
957
+ isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
958
+ (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
959
+ (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
960
+ (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
961
+
962
+ scale += 4;
963
+ }
964
+
965
+ sum += d_all * y[i].d * (isum - 32 * mins);
966
+ }
967
+
968
+ *s = sum;
969
+
970
+ #else
971
+
972
+ int8_t aux8[QK_K];
973
+ int16_t aux16[8];
974
+ float sums [8];
975
+ int32_t aux32[8];
976
+ memset(sums, 0, 8*sizeof(float));
977
+
978
+ float sumf = 0;
979
+ for (int i = 0; i < nb; ++i) {
980
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
981
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
982
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
983
+ memset(aux32, 0, 8*sizeof(int32_t));
984
+ int8_t * GGML_RESTRICT a = aux8;
985
+ for (int j = 0; j < QK_K; j += 128) {
986
+ for (int l = 0; l < 32; ++l) {
987
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
988
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
989
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
990
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
991
+ }
992
+ a += 128;
993
+ q4 += 64;
994
+ qh += 32;
995
+ }
996
+ a = aux8;
997
+ int is = 0;
998
+ for (int j = 0; j < QK_K/16; ++j) {
999
+ int scale = x[i].scales[is++];
1000
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1001
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1002
+ q8 += 8; a += 8;
1003
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1004
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1005
+ q8 += 8; a += 8;
1006
+ }
1007
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1008
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1009
+ }
1010
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1011
+ *s = sumf;
1012
+ #endif
1013
+ }
1014
+
1015
+ // #if defined(__VXE__) || defined(__VXE2__)
1016
+ // static const int8_t keven_signs_q2xs[1024] = {
1017
+ // 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1018
+ // 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1019
+ // 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1020
+ // 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1021
+ // 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1022
+ // 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1023
+ // 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1024
+ // 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1025
+ // 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1026
+ // 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1027
+ // 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1028
+ // 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1029
+ // 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1030
+ // 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1031
+ // 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1032
+ // 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1033
+ // 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1034
+ // 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1035
+ // 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1036
+ // 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1037
+ // 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1038
+ // 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1039
+ // 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1040
+ // 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1041
+ // 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1042
+ // 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1043
+ // 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1044
+ // 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1045
+ // 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1046
+ // 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1047
+ // 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1048
+ // 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
1049
+ // };
1050
+ // #endif
1051
+
1052
+ // void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1053
+ // assert(n % QK_K == 0);
1054
+ // assert(nrc == 1);
1055
+ // UNUSED(nrc);
1056
+ // UNUSED(bx);
1057
+ // UNUSED(by);
1058
+ // UNUSED(bs);
1059
+
1060
+ // const block_iq2_xxs * GGML_RESTRICT x = vx;
1061
+ // const block_q8_K * GGML_RESTRICT y = vy;
1062
+
1063
+ // const int nb = n / QK_K;
1064
+
1065
+ // #if defined(__VXE__) || defined(__VXE2__)
1066
+ // const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
1067
+
1068
+ // uint32_t aux32[4];
1069
+ // const uint8_t * aux8 = (const uint8_t *)aux32;
1070
+
1071
+ // float sumf = 0;
1072
+
1073
+ // for (int i = 0; i < nb; ++i) {
1074
+ // const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1075
+ // const uint16_t * GGML_RESTRICT q2 = x[i].qs;
1076
+ // const int8_t * GGML_RESTRICT q8 = y[i].qs;
1077
+
1078
+ // float sumf1 = 0, sumf2 = 0;
1079
+
1080
+ // for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1081
+ // int8x16_t q8b0 = vec_xl( 0, q8);
1082
+ // int8x16_t q8b1 = vec_xl(16, q8);
1083
+ // int8x16_t q8b2 = vec_xl(32, q8);
1084
+ // int8x16_t q8b3 = vec_xl(48, q8);
1085
+ // q8 += 64;
1086
+
1087
+ // memcpy(aux32, q2, 4 * sizeof(uint32_t));
1088
+ // q2 += 8;
1089
+
1090
+ // int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
1091
+ // int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
1092
+ // int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
1093
+ // int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
1094
+
1095
+ // int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
1096
+ // int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
1097
+ // int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
1098
+ // int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
1099
+
1100
+ // q2u0 = vec_mul(q2u0, q2s0);
1101
+ // q2u1 = vec_mul(q2u1, q2s1);
1102
+ // q2u2 = vec_mul(q2u2, q2s2);
1103
+ // q2u3 = vec_mul(q2u3, q2s3);
1104
+
1105
+ // const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
1106
+ // const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
1107
+
1108
+ // sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
1109
+ // sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
1110
+ // }
1111
+
1112
+ // sumf += d * (sumf1 + sumf2);
1113
+ // }
1114
+
1115
+ // *s = 0.25f * sumf;
1116
+
1117
+ // #else
1118
+
1119
+ // uint32_t aux32[2];
1120
+ // const uint8_t * aux8 = (const uint8_t *)aux32;
1121
+
1122
+ // float sumf = 0.f;
1123
+ // for (int i = 0; i < nb; ++i) {
1124
+ // const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1125
+ // const uint16_t * GGML_RESTRICT q2 = x[i].qs;
1126
+ // const int8_t * GGML_RESTRICT q8 = y[i].qs;
1127
+ // int32_t bsum = 0;
1128
+ // for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
1129
+ // memcpy(aux32, q2, 2*sizeof(uint32_t));
1130
+ // q2 += 4;
1131
+ // const uint32_t ls = 2*(aux32[1] >> 28) + 1;
1132
+ // int32_t sumi = 0;
1133
+ // for (int l = 0; l < 4; ++l) {
1134
+ // const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
1135
+ // const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
1136
+ // for (int j = 0; j < 8; ++j) {
1137
+ // sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
1138
+ // }
1139
+ // q8 += 8;
1140
+ // }
1141
+ // bsum += sumi * ls;
1142
+ // }
1143
+ // sumf += d * bsum;
1144
+ // }
1145
+ // *s = 0.125f * sumf;
1146
+ // #endif
1147
+ // }
1148
+
1149
+ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1150
+ assert(nrc == 1);
1151
+ UNUSED(nrc);
1152
+ UNUSED(bx);
1153
+ UNUSED(by);
1154
+ UNUSED(bs);
1155
+ assert(n % QK4_NL == 0);
1156
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
1157
+
1158
+ const block_iq4_nl * GGML_RESTRICT x = vx;
1159
+ const block_q8_0 * GGML_RESTRICT y = vy;
1160
+
1161
+ const int nb = n / QK4_NL;
1162
+
1163
+ int ib = 0;
1164
+ float sumf = 0;
1165
+
1166
+ #if defined(__VXE__) || defined(__VXE2__)
1167
+ const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
1168
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
1169
+
1170
+ for (; ib < nb; ++ib) {
1171
+ const block_iq4_nl * GGML_RESTRICT x0 = &x[ib];
1172
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
1173
+
1174
+ const uint8x16_t v_x = vec_xl(0, x0->qs);
1175
+ int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
1176
+ int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
1177
+
1178
+ v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
1179
+ v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
1180
+
1181
+ const int8x16_t v_yl = vec_xl(0 , y0->qs);
1182
+ const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
1183
+ const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
1184
+
1185
+ sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
1186
+ }
1187
+
1188
+ #endif
1189
+ for (; ib < nb; ++ib) {
1190
+ const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
1191
+ int sumi1 = 0, sumi2 = 0;
1192
+ for (int j = 0; j < QK4_NL/2; ++j) {
1193
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
1194
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
1195
+ }
1196
+ sumf += d * (sumi1 + sumi2);
1197
+ }
1198
+ *s = sumf;
1199
+ }
1200
+
1201
+ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1202
+ assert(nrc == 1);
1203
+ UNUSED(nrc);
1204
+ UNUSED(bx);
1205
+ UNUSED(by);
1206
+ UNUSED(bs);
1207
+ assert(n % QK_K == 0);
1208
+
1209
+ const block_iq4_xs * GGML_RESTRICT x = vx;
1210
+ const block_q8_K * GGML_RESTRICT y = vy;
1211
+
1212
+ const int nb = n / QK_K;
1213
+
1214
+ #if defined(__VXE__) || defined(__VXE2__)
1215
+ const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
1216
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
1217
+
1218
+ float sumf = 0;
1219
+
1220
+ for (int ibl = 0; ibl < nb; ++ibl) {
1221
+ const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;
1222
+ const int8_t * GGML_RESTRICT q8 = y[ibl].qs;
1223
+
1224
+ uint16_t h = x[ibl].scales_h;
1225
+
1226
+ int sumi1 = 0, sumi2 = 0;
1227
+ for (int ib = 0; ib < QK_K/64; ++ib) {
1228
+ const uint8x16_t v_x0 = vec_xl(0 , q4);
1229
+ const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
1230
+ q4 += 32;
1231
+
1232
+ int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
1233
+ int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
1234
+ int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
1235
+ int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
1236
+
1237
+ v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
1238
+ v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
1239
+ v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
1240
+ v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
1241
+
1242
+ const int8x16_t v_y0 = vec_xl( 0, q8);
1243
+ const int8x16_t v_y1 = vec_xl(16, q8);
1244
+ const int8x16_t v_y2 = vec_xl(32, q8);
1245
+ const int8x16_t v_y3 = vec_xl(48, q8);
1246
+ q8 += 64;
1247
+
1248
+ int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
1249
+ int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
1250
+
1251
+ int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
1252
+ int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
1253
+
1254
+ h >>= 4;
1255
+
1256
+ sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
1257
+ sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
1258
+ }
1259
+
1260
+ sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
1261
+ }
1262
+
1263
+ *s = sumf;
1264
+
1265
+ #else
1266
+ float sumf = 0;
1267
+ for (int ibl = 0; ibl < nb; ++ibl) {
1268
+ const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
1269
+ uint16_t h = x[ibl].scales_h;
1270
+ const uint8_t * qs = x[ibl].qs;
1271
+ const int8_t * q8 = y[ibl].qs;
1272
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
1273
+ const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
1274
+ const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
1275
+ h >>= 4;
1276
+ const float d1 = d4d8*(ls1 - 32);
1277
+ const float d2 = d4d8*(ls2 - 32);
1278
+ int sumi1 = 0, sumi2 = 0;
1279
+ for (int j = 0; j < 16; ++j) {
1280
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
1281
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
1282
+ }
1283
+ sumf += d1 * (sumi1 + sumi2);
1284
+ qs += 16;
1285
+ q8 += 32;
1286
+ sumi1 = sumi2 = 0;
1287
+ for (int j = 0; j < 16; ++j) {
1288
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
1289
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
1290
+ }
1291
+ sumf += d2 * (sumi1 + sumi2);
1292
+ qs += 16;
1293
+ q8 += 32;
1294
+ }
1295
+ }
1296
+ *s = sumf;
1297
+ #endif
1298
+ }
1299
+
ggml/src/ggml-cpu/arch/wasm/quants.c ADDED
@@ -0,0 +1,1480 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+
7
+ #include "../../quants.h"
8
+ #include "../../ggml-cpu-impl.h"
9
+
10
+ #include <math.h>
11
+ #include <string.h>
12
+ #include <assert.h>
13
+ #include <float.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
+ #if defined(__wasm_simd128__)
26
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
27
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
28
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
29
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
30
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
31
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
32
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
33
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
34
+
35
+ // precomputed tables for expanding 8bits to 8 bytes:
36
+ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
37
+ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
38
+ #endif
39
+
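The B1..B8 macro ladder above expands, at compile time, every 8-bit mask into a uint64_t whose bytes are 0x10 or 0x00 depending on the corresponding bit; table_b2b_1 is the bit-inverted variant. A rough runtime equivalent of one table_b2b_0 entry, shown only to make the lookups in the q5 kernels below easier to follow (the function name b2b_0 is illustrative, not ggml API):

    #include <stdint.h>

    // Sketch of what table_b2b_0[bits] holds: bit j of the input selects byte j
    // of the result (in the little-endian memory view the kernels rely on),
    // 0x10 if the bit is set and 0x00 otherwise. table_b2b_1 is the same with
    // the condition inverted.
    static uint64_t b2b_0(uint8_t bits) {
        uint64_t out = 0;
        for (int j = 0; j < 8; ++j) {
            if (bits & (1u << j)) {
                out |= (uint64_t)0x10 << (8 * j);
            }
        }
        return out;
    }

In the kernels below, the q5_0 path subtracts the table_b2b_1 bytes so that a clear high bit becomes a -16 offset, while the q5_1 path ORs in the table_b2b_0 bytes to set bit 4 of each quant directly.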
40
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
41
+ assert(QK8_0 == 32);
42
+ assert(k % QK8_0 == 0);
43
+ const int nb = k / QK8_0;
44
+
45
+ block_q8_0 * GGML_RESTRICT y = vy;
46
+
47
+ #if defined __wasm_simd128__
48
+ for (int i = 0; i < nb; i++) {
49
+ v128_t srcv [8];
50
+ v128_t asrcv[8];
51
+ v128_t amaxv[8];
52
+
53
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
54
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
55
+
56
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
57
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
58
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
59
+
60
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
61
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
62
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
63
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
64
+
65
+ const float d = amax / ((1 << 7) - 1);
66
+ const float id = d ? 1.0f/d : 0.0f;
67
+
68
+ y[i].d = GGML_FP32_TO_FP16(d);
69
+
70
+ for (int j = 0; j < 8; j++) {
71
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
72
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
73
+
74
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
75
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
76
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
77
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
78
+ }
79
+ }
80
+ #else
81
+ GGML_UNUSED(nb);
82
+ // scalar
83
+ quantize_row_q8_0_ref(x, y, k);
84
+ #endif
85
+ }
86
+
87
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
88
+ assert(k % QK8_1 == 0);
89
+ const int nb = k / QK8_1;
90
+
91
+ block_q8_1 * GGML_RESTRICT y = vy;
92
+ #if defined __wasm_simd128__
93
+ for (int i = 0; i < nb; i++) {
94
+ v128_t srcv [8];
95
+ v128_t asrcv[8];
96
+ v128_t amaxv[8];
97
+
98
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
99
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
100
+
101
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
102
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
103
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
104
+
105
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
106
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
107
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
108
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
109
+
110
+ const float d = amax / ((1 << 7) - 1);
111
+ const float id = d ? 1.0f/d : 0.0f;
112
+
113
+ y[i].d = GGML_FP32_TO_FP16(d);
114
+
115
+ v128_t accv = wasm_i32x4_splat(0);
116
+
117
+ for (int j = 0; j < 8; j++) {
118
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
119
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
120
+
121
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
122
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
123
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
124
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
125
+
126
+ accv = wasm_i32x4_add(accv, vi);
127
+ }
128
+
129
+ y[i].s = GGML_FP32_TO_FP16(
130
+ d * (wasm_i32x4_extract_lane(accv, 0) +
131
+ wasm_i32x4_extract_lane(accv, 1) +
132
+ wasm_i32x4_extract_lane(accv, 2) +
133
+ wasm_i32x4_extract_lane(accv, 3)));
134
+ }
135
+ #else
136
+ GGML_UNUSED(nb);
137
+ // scalar
138
+ quantize_row_q8_1_ref(x, y, k);
139
+ #endif
140
+ }
141
+
142
+ //===================================== Q8_K ==============================================
143
+
144
+ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
145
+ #ifdef __wasm_simd128__
146
+ assert(k % QK_K == 0);
147
+ const int64_t nb = k / QK_K;
148
+ block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type
149
+
150
+ for (int i = 0; i < nb; i++) {
151
+ const float * x_block = x + i * QK_K;
152
+
153
+ v128_t min_vec = wasm_v128_load(x_block);
154
+ v128_t max_vec = min_vec;
155
+
156
+ for (int j = 4; j < QK_K; j += 4) {
157
+ v128_t x_vec = wasm_v128_load(x_block + j);
158
+ max_vec = wasm_f32x4_pmax(max_vec, x_vec);
159
+ min_vec = wasm_f32x4_pmin(min_vec, x_vec);
160
+ }
161
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
162
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
163
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
164
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
165
+ float max = wasm_f32x4_extract_lane(max_vec, 0);
166
+ float min = wasm_f32x4_extract_lane(min_vec, 0);
167
+ float amax = -min > max ? min : max;
168
+
169
+ if (amax == 0.0f) {
170
+ yc[i].d = 0.0f;
171
+ const v128_t zero = wasm_i8x16_splat(0);
172
+ for (int j = 0; j < QK_K; j += 16) {
173
+ wasm_v128_store(yc[i].qs + j, zero);
174
+ }
175
+ continue;
176
+ }
177
+
178
+ const float iscale = -127.0f / amax;
179
+ const v128_t scale_vec = wasm_f32x4_splat(iscale);
180
+
181
+ // Process 16 elements per iteration
182
+ for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
183
+ // Load and quantize 16 floats
184
+ v128_t x0 = wasm_v128_load(x_block + j);
185
+ v128_t x1 = wasm_v128_load(x_block + j + 4);
186
+ v128_t x2 = wasm_v128_load(x_block + j + 8);
187
+ v128_t x3 = wasm_v128_load(x_block + j + 12);
188
+
189
+ v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
190
+ v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
191
+ v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
192
+ v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
193
+
194
+ // Convert to i32 with saturation
195
+ v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
196
+ v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
197
+ v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
198
+ v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
199
+
200
+ // Pack into 16 i8 values
201
+ v128_t i8 = wasm_i8x16_narrow_i16x8(
202
+ wasm_i16x8_narrow_i32x4(i0, i1),
203
+ wasm_i16x8_narrow_i32x4(i2, i3)
204
+ );
205
+ wasm_v128_store(yc[i].qs + j, i8);
206
+
207
+ // Calculate bsums using SIMD
208
+ v128_t sum16 = wasm_i16x8_add(
209
+ wasm_i16x8_extend_low_i8x16(i8),
210
+ wasm_i16x8_extend_high_i8x16(i8)
211
+ );
212
+ v128_t sum32 = wasm_i32x4_add(
213
+ wasm_i32x4_extend_low_i16x8(sum16),
214
+ wasm_i32x4_extend_high_i16x8(sum16)
215
+ );
216
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
217
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
218
+ yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
219
+ }
220
+
221
+ yc[i].d = 1.0f / iscale;
222
+ }
223
+ #else
224
+ quantize_row_q8_K_ref(x, y, k);
225
+ #endif
226
+ }
227
+
228
+
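For reference, the per-block math that the SIMD path above implements, written out as a plain scalar sketch close to what the quantize_row_q8_K_ref fallback computes (the function name, the QK_K definition here, and the explicit clamp are illustrative, not ggml API):

    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    #define QK_K 256  // assumption: the super-block size used throughout this file

    // Sketch: find the value with the largest magnitude, map it to -127, store
    // the reciprocal scale d so that x ~= d * q, and keep 16-wide partial sums
    // (bsums) that the K-quant dot products use for their min/bias corrections.
    static void quantize_block_q8_K_sketch(const float x[QK_K], int8_t q[QK_K],
                                           float * d, int16_t bsums[QK_K/16]) {
        float max = 0.0f, amax = 0.0f;
        for (int j = 0; j < QK_K; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) { amax = ax; max = x[j]; }
        }
        if (amax == 0.0f) {
            *d = 0.0f;
            memset(q, 0, QK_K);
            memset(bsums, 0, (QK_K/16) * sizeof(bsums[0]));  // sketch also zeroes bsums
            return;
        }
        const float iscale = -127.0f / max;
        for (int j = 0; j < QK_K/16; ++j) {
            int sum = 0;
            for (int l = 0; l < 16; ++l) {
                int v = (int)nearbyintf(iscale * x[16*j + l]);
                v = v < -127 ? -127 : (v > 127 ? 127 : v);  // hypothetical guard
                q[16*j + l] = (int8_t)v;
                sum += v;
            }
            bsums[j] = (int16_t)sum;
        }
        *d = 1.0f / iscale;
    }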
229
+ //===================================== Dot products =================================
230
+
231
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
232
+ const int qk = QK8_0;
233
+ const int nb = n / qk;
234
+
235
+ assert(n % qk == 0);
236
+ assert(nrc == 1);
237
+ UNUSED(nrc);
238
+ UNUSED(bx);
239
+ UNUSED(by);
240
+ UNUSED(bs);
241
+
242
+ const block_q4_0 * GGML_RESTRICT x = vx;
243
+ const block_q8_0 * GGML_RESTRICT y = vy;
244
+
245
+ int ib = 0;
246
+ float sumf = 0;
247
+
248
+ #if defined __wasm_simd128__
249
+ v128_t sumv = wasm_f32x4_splat(0.0f);
250
+
251
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
252
+ const v128_t s8b = wasm_i8x16_splat(0x8);
253
+
254
+ for (; ib + 1 < nb; ib += 2) {
255
+ const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
256
+ const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
257
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
258
+ const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
259
+
260
+ // Load and process x0
261
+ v128_t v0_0 = wasm_v128_load(x0->qs);
262
+ v128_t v0_0l = wasm_v128_and(v0_0, m4b);
263
+ v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
264
+ v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
265
+ v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
266
+
267
+ // Load y0 vectors
268
+ v128_t y0_l = wasm_v128_load(y0->qs);
269
+ v128_t y0_h = wasm_v128_load(y0->qs + 16);
270
+
271
+ // Extend to i16x8 and compute dot products
272
+ v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
273
+ v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
274
+ v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
275
+ v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
276
+
277
+ v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
278
+ v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
279
+ v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
280
+ v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
281
+
282
+ v128_t dp0 = wasm_i32x4_add(
283
+ wasm_i32x4_add(
284
+ wasm_i32x4_dot_i16x8(dx0l, dy0ll),
285
+ wasm_i32x4_dot_i16x8(dx0h, dy0lh)
286
+ ),
287
+ wasm_i32x4_add(
288
+ wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
289
+ wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
290
+ )
291
+ );
292
+
293
+ // Load and process x1
294
+ v128_t v0_1 = wasm_v128_load(x1->qs);
295
+ v128_t v0_1l = wasm_v128_and(v0_1, m4b);
296
+ v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
297
+ v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
298
+ v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
299
+
300
+ // Load y1 vectors
301
+ v128_t y1_l = wasm_v128_load(y1->qs);
302
+ v128_t y1_h = wasm_v128_load(y1->qs + 16);
303
+
304
+ // Extend to i16x8 and compute dot products
305
+ v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
306
+ v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
307
+ v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
308
+ v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
309
+
310
+ v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
311
+ v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
312
+ v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
313
+ v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
314
+
315
+ v128_t dp1 = wasm_i32x4_add(
316
+ wasm_i32x4_add(
317
+ wasm_i32x4_dot_i16x8(dx1l, dy1ll),
318
+ wasm_i32x4_dot_i16x8(dx1h, dy1lh)
319
+ ),
320
+ wasm_i32x4_add(
321
+ wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
322
+ wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
323
+ )
324
+ );
325
+
326
+ // Accumulate results with scaling
327
+ float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
328
+ float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
329
+
330
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
331
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
332
+ }
333
+
334
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
335
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
336
+
337
+ #endif
338
+ for (; ib < nb; ++ib) {
339
+ int sumi0 = 0;
340
+ int sumi1 = 0;
341
+
342
+ for (int j = 0; j < qk/2; ++j) {
343
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
344
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
345
+
346
+ sumi0 += (v0 * y[ib].qs[j]);
347
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
348
+ }
349
+
350
+ int sumi = sumi0 + sumi1;
351
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
352
+ }
353
+
354
+ *s = sumf;
355
+ }
356
+
357
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
358
+ const int qk = QK8_0;
359
+ const int nb = n / qk;
360
+
361
+ int ib = 0;
362
+ float sumf = 0;
363
+
364
+ assert(n % qk == 0);
365
+ assert(qk == QK5_0);
366
+ assert(nrc == 1);
367
+ UNUSED(nrc);
368
+ UNUSED(bx);
369
+ UNUSED(by);
370
+ UNUSED(bs);
371
+
372
+ const block_q5_0 * GGML_RESTRICT x = vx;
373
+ const block_q8_0 * GGML_RESTRICT y = vy;
374
+
375
+ #if defined __wasm_simd128__
376
+ v128_t sumv = wasm_f32x4_splat(0.0f);
377
+
378
+ uint32_t qh_;
379
+ uint64_t tmp[4];
380
+
381
+ // TODO: check if unrolling this is better
382
+ for (; ib < nb; ++ib) {
383
+ const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
384
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
385
+
386
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
387
+
388
+ // extract the 5th bit
389
+ memcpy(&qh_, x0->qh, sizeof(qh_));
390
+
391
+ tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
392
+ tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
393
+ tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
394
+ tmp[3] = table_b2b_1[(qh_ >> 24) ];
395
+
396
+ const v128_t qhl = wasm_v128_load(tmp + 0);
397
+ const v128_t qhh = wasm_v128_load(tmp + 2);
398
+
399
+ const v128_t v0 = wasm_v128_load(x0->qs);
400
+
401
+ // 4-bit -> 8-bit
402
+ const v128_t v0l = wasm_v128_and (v0, m4b);
403
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
404
+
405
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
406
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
407
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
408
+
409
+ // load y
410
+ const v128_t v1l = wasm_v128_load(y0->qs);
411
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
412
+
413
+ // int8x16 -> int16x8
414
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
415
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
416
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
417
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
418
+
419
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
420
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
421
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
422
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
423
+
424
+ // dot product
425
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
426
+ wasm_i32x4_add(
427
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
428
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
429
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
430
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
431
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
432
+ }
433
+
434
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
435
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
436
+
437
+ #endif
438
+ for (; ib < nb; ++ib) {
439
+ uint32_t qh;
440
+ memcpy(&qh, x[ib].qh, sizeof(qh));
441
+
442
+ int sumi0 = 0;
443
+ int sumi1 = 0;
444
+
445
+ for (int j = 0; j < qk/2; ++j) {
446
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
447
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
448
+
449
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
450
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
451
+
452
+ sumi0 += (x0 * y[ib].qs[j]);
453
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
454
+ }
455
+
456
+ int sumi = sumi0 + sumi1;
457
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
458
+ }
459
+
460
+ *s = sumf;
461
+ }
462
+
463
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
464
+ const int qk = QK8_1;
465
+ const int nb = n / qk;
466
+
467
+ int ib = 0;
468
+ float sumf = 0;
469
+
470
+ assert(n % qk == 0);
471
+ assert(qk == QK5_1);
472
+ assert(nrc == 1);
473
+ UNUSED(nrc);
474
+ UNUSED(bx);
475
+ UNUSED(by);
476
+ UNUSED(bs);
477
+
478
+ const block_q5_1 * GGML_RESTRICT x = vx;
479
+ const block_q8_1 * GGML_RESTRICT y = vy;
480
+
481
+ #if defined __wasm_simd128__
482
+ v128_t sumv = wasm_f32x4_splat(0.0f);
483
+
484
+ float summs = 0.0f;
485
+
486
+ uint32_t qh_;
487
+ uint64_t tmp[4];
488
+
489
+ // TODO: check if unrolling this is better
490
+ for (; ib < nb; ++ib) {
491
+ const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
492
+ const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
493
+
494
+ summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
495
+
496
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
497
+
498
+ // extract the 5th bit
499
+ memcpy(&qh_, x0->qh, sizeof(qh_));
500
+
501
+ tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
502
+ tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
503
+ tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
504
+ tmp[3] = table_b2b_0[(qh_ >> 24) ];
505
+
506
+ const v128_t qhl = wasm_v128_load(tmp + 0);
507
+ const v128_t qhh = wasm_v128_load(tmp + 2);
508
+
509
+ const v128_t v0 = wasm_v128_load(x0->qs);
510
+
511
+ // 4-bit -> 8-bit
512
+ const v128_t v0l = wasm_v128_and (v0, m4b);
513
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
514
+
515
+ // add high bit
516
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
517
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
518
+
519
+ // load y
520
+ const v128_t v1l = wasm_v128_load(y0->qs);
521
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
522
+
523
+ // int8x16 -> int16x8
524
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
525
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
526
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
527
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
528
+
529
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
530
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
531
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
532
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
533
+
534
+ // dot product
535
+ sumv = wasm_f32x4_add(sumv,
536
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
537
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
538
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
539
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
540
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
541
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
542
+ }
543
+
544
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
545
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
546
+
547
+ #endif
548
+ for (; ib < nb; ++ib) {
549
+ uint32_t qh;
550
+ memcpy(&qh, x[ib].qh, sizeof(qh));
551
+
552
+ int sumi0 = 0;
553
+ int sumi1 = 0;
554
+
555
+ for (int j = 0; j < qk/2; ++j) {
556
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
557
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
558
+
559
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
560
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
561
+
562
+ sumi0 += (x0 * y[ib].qs[j]);
563
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
564
+ }
565
+
566
+ int sumi = sumi0 + sumi1;
567
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
568
+ }
569
+
570
+ *s = sumf;
571
+ }
572
+
573
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
574
+ const int qk = QK8_0;
575
+ const int nb = n / qk;
576
+
577
+ assert(n % qk == 0);
578
+ assert(nrc == 1);
579
+ UNUSED(nrc);
580
+ UNUSED(bx);
581
+ UNUSED(by);
582
+ UNUSED(bs);
583
+
584
+ const block_q8_0 * GGML_RESTRICT x = vx;
585
+ const block_q8_0 * GGML_RESTRICT y = vy;
586
+
587
+ int ib = 0;
588
+ float sumf = 0;
589
+
590
+ #if defined __wasm_simd128__
591
+ v128_t sumv = wasm_f32x4_splat(0.0f);
592
+
593
+ for (; ib < nb; ++ib) {
594
+ const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
595
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
596
+
597
+ const v128_t x0_0 = wasm_v128_load(x0->qs);
598
+ const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
599
+ const v128_t y0_0 = wasm_v128_load(y0->qs);
600
+ const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
601
+
602
+ // Extend 8-bit to 16-bit
603
+ const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
604
+ const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
605
+ const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
606
+ const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
607
+
608
+ const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
609
+ const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
610
+ const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
611
+ const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
612
+
613
+ // Compute dot products
614
+ const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
615
+ const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
616
+ const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
617
+ const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
618
+
619
+ // Sum all dot products
620
+ const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
621
+
622
+ // Convert to float and accumulate
623
+ const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
624
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
625
+ }
626
+
627
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
628
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
629
+
630
+ #endif
631
+ for (; ib < nb; ++ib) {
632
+ int sumi = 0;
633
+
634
+ for (int j = 0; j < qk; j++) {
635
+ sumi += x[ib].qs[j]*y[ib].qs[j];
636
+ }
637
+
638
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
639
+ }
640
+
641
+ *s = sumf;
642
+ }
643
+
644
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
645
+ assert(nrc == 1);
646
+ UNUSED(nrc);
647
+ UNUSED(bx);
648
+ UNUSED(by);
649
+ UNUSED(bs);
650
+
651
+ const block_q2_K * GGML_RESTRICT x = vx;
652
+ const block_q8_K * GGML_RESTRICT y = vy;
653
+
654
+ const int nb = n / QK_K;
655
+
656
+ #if defined __wasm_simd128__
657
+ float sumf = 0;
658
+
659
+ for (int i = 0; i < nb; ++i) {
660
+ const uint8_t * q2 = x[i].qs;
661
+ const int8_t * q8 = y[i].qs;
662
+ const uint8_t * sc = x[i].scales;
663
+
664
+ // Vectorized summs calculation
665
+ v128_t summs_vec = wasm_i32x4_splat(0);
666
+ {
667
+ v128_t sc_vec = wasm_v128_load(sc);
668
+ v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
669
+
670
+ v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
671
+ v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
672
+
673
+ v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
674
+ v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
675
+
676
+ summs_vec = wasm_i32x4_add(
677
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
678
+ wasm_i32x4_dot_i16x8(sc_high, bsums2)),
679
+ summs_vec
680
+ );
681
+
682
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
683
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
684
+ }
685
+ int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
686
+
687
+ // Vectorized isum calculation
688
+ int32_t isum = 0;
689
+ const uint8_t * sc_ptr = sc;
690
+ const int k_iters = QK_K/128;
691
+
692
+ for (int k = 0; k < k_iters; ++k) {
693
+ v128_t isum_vec = wasm_i32x4_splat(0);
694
+ int shift = 0;
695
+
696
+ for (int j = 0; j < 4; ++j) {
697
+ const int d0 = (sc_ptr[0] & 0xF);
698
+ const int d1 = (sc_ptr[1] & 0xF);
699
+ sc_ptr += 2;
700
+
701
+ // Process first 16 elements
702
+ v128_t q2_0 = wasm_v128_load(q2);
703
+ v128_t q8_0 = wasm_v128_load(q8);
704
+ v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
705
+ v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
706
+
707
+ // Process next 16 elements
708
+ v128_t q2_1 = wasm_v128_load(q2 + 16);
709
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
710
+ v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
711
+ v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
712
+
713
+ // Calculate dot products
714
+ v128_t p0 = wasm_i32x4_dot_i16x8(
715
+ wasm_i16x8_extend_low_i8x16(q8_0),
716
+ wasm_i16x8_extend_low_i8x16(q2_bits_0)
717
+ );
718
+ v128_t p1 = wasm_i32x4_dot_i16x8(
719
+ wasm_i16x8_extend_high_i8x16(q8_0),
720
+ wasm_i16x8_extend_high_i8x16(q2_bits_0)
721
+ );
722
+ v128_t p2 = wasm_i32x4_dot_i16x8(
723
+ wasm_i16x8_extend_low_i8x16(q8_1),
724
+ wasm_i16x8_extend_low_i8x16(q2_bits_1)
725
+ );
726
+ v128_t p3 = wasm_i32x4_dot_i16x8(
727
+ wasm_i16x8_extend_high_i8x16(q8_1),
728
+ wasm_i16x8_extend_high_i8x16(q2_bits_1)
729
+ );
730
+
731
+ // Accumulate scaled results
732
+ v128_t scaled = wasm_i32x4_add(
733
+ wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
734
+ wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
735
+ );
736
+
737
+ isum_vec = wasm_i32x4_add(isum_vec, scaled);
738
+ q8 += 32;
739
+ shift += 2;
740
+ }
741
+ q2 += 32;
742
+
743
+ // Horizontal sum of isum_vec
744
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
745
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
746
+ isum += wasm_i32x4_extract_lane(isum_vec, 0);
747
+ }
748
+
749
+ const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
750
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
751
+ sumf += dall * isum - dmin * summs;
752
+ }
753
+
754
+ *s = sumf;
755
+
756
+ #else
757
+
758
+ float sumf = 0;
759
+
760
+ for (int i = 0; i < nb; ++i) {
761
+
762
+ const uint8_t * q2 = x[i].qs;
763
+ const int8_t * q8 = y[i].qs;
764
+ const uint8_t * sc = x[i].scales;
765
+
766
+ int summs = 0;
767
+ for (int j = 0; j < 16; ++j) {
768
+ summs += y[i].bsums[j] * (sc[j] >> 4);
769
+ }
770
+
771
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
772
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
773
+
774
+ int isum = 0;
775
+ int is = 0;
776
+ int d;
777
+ for (int k = 0; k < QK_K/128; ++k) {
778
+ int shift = 0;
779
+ for (int j = 0; j < 4; ++j) {
780
+ d = sc[is++] & 0xF;
781
+ int isuml = 0;
782
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
783
+ isum += d * isuml;
784
+ d = sc[is++] & 0xF;
785
+ isuml = 0;
786
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
787
+ isum += d * isuml;
788
+ shift += 2;
789
+ q8 += 32;
790
+ }
791
+ q2 += 32;
792
+ }
793
+ sumf += dall * isum - dmin * summs;
794
+ }
795
+ *s = sumf;
796
+ #endif
797
+ }
798
+
799
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
800
+ assert(n % QK_K == 0);
801
+ assert(nrc == 1);
802
+ UNUSED(nrc);
803
+ UNUSED(bx);
804
+ UNUSED(by);
805
+ UNUSED(bs);
806
+
807
+ const uint32_t kmask1 = 0x03030303;
808
+ const uint32_t kmask2 = 0x0f0f0f0f;
809
+
810
+ const block_q3_K * GGML_RESTRICT x = vx;
811
+ const block_q8_K * GGML_RESTRICT y = vy;
812
+
813
+ const int nb = n / QK_K;
814
+
815
+ #if defined __wasm_simd128__
816
+ int8_t aux8[QK_K];
817
+ float sums[8] = {0};
818
+ uint32_t auxs[4];
819
+
820
+ float sumf = 0;
821
+ for (int i = 0; i < nb; ++i) {
822
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
823
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
824
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
825
+
826
+ // Process blocks with SIMD
827
+ int8_t * a = aux8;
828
+ uint8_t m = 1;
829
+ for (int j = 0; j < QK_K; j += 128) {
830
+ for (int shift = 0; shift <= 6; shift += 2) {
831
+ v128_t v_m = wasm_i8x16_splat(m);
832
+ for (int l = 0; l < 32; l += 16) {
833
+ v128_t v_q3 = wasm_v128_load(q3 + l);
834
+ v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
835
+ v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
836
+
837
+ v128_t v_hm = wasm_v128_load(hm + l);
838
+ v128_t v_mask = wasm_v128_and(v_hm, v_m);
839
+ v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
840
+
841
+ v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
842
+ wasm_v128_store(a + l, v_low2);
843
+ }
844
+ a += 32;
845
+ m <<= 1;
846
+ }
847
+ q3 += 32;
848
+ }
849
+
850
+ // Extract scales
851
+ memcpy(auxs, x[i].scales, 12);
852
+ uint32_t tmp = auxs[2];
853
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
854
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
855
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
856
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
857
+ const int8_t * scales = (const int8_t *)auxs;
858
+
859
+ // SIMD dot product with register accumulators
860
+ v128_t v_acc0 = wasm_i32x4_splat(0);
861
+ v128_t v_acc1 = wasm_i32x4_splat(0);
862
+ a = aux8;
863
+ for (int j = 0; j < QK_K/16; ++j) {
864
+ const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
865
+
866
+ // Process 16 elements per iteration
867
+ for (int k = 0; k < 2; ++k) {
868
+ const v128_t v_q8 = wasm_i16x8_load8x8(q8);
869
+ const v128_t v_a = wasm_i16x8_load8x8(a);
870
+
871
+ v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
872
+ v_prod = wasm_i16x8_mul(v_prod, v_scale);
873
+
874
+ v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
875
+ v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
876
+
877
+ q8 += 8;
878
+ a += 8;
879
+ }
880
+ }
881
+
882
+ // Accumulate results
883
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
884
+ const v128_t v_d = wasm_f32x4_splat(d);
885
+ v128_t v_sum = wasm_f32x4_add(
886
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
887
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
888
+ );
889
+
890
+ // Accumulate into sums vector
891
+ wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
892
+ }
893
+
894
+ // Horizontal sum
895
+ v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
896
+ sumf = wasm_f32x4_extract_lane(v_sum, 0) +
897
+ wasm_f32x4_extract_lane(v_sum, 1) +
898
+ wasm_f32x4_extract_lane(v_sum, 2) +
899
+ wasm_f32x4_extract_lane(v_sum, 3);
900
+
901
+ *s = sumf;
902
+
903
+ #else
904
+ // scalar version
905
+ // This function is written like this so the compiler can manage to vectorize most of it
906
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
907
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
908
+ // The ideal situation would be if we could just write the code once, and the compiler would
909
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
910
+ // write vectorized versions for AVX, ARM_NEON, etc.
911
+
912
+ int8_t aux8[QK_K];
913
+ int16_t aux16[8];
914
+ float sums [8];
915
+ int32_t aux32[8];
916
+ memset(sums, 0, 8*sizeof(float));
917
+
918
+ uint32_t auxs[4];
919
+ const int8_t * scales = (const int8_t*)auxs;
920
+
921
+ float sumf = 0;
922
+ for (int i = 0; i < nb; ++i) {
923
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
924
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
925
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
926
+ memset(aux32, 0, 8*sizeof(int32_t));
927
+ int8_t * GGML_RESTRICT a = aux8;
928
+ uint8_t m = 1;
929
+ for (int j = 0; j < QK_K; j += 128) {
930
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
931
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
932
+ a += 32; m <<= 1;
933
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
934
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
935
+ a += 32; m <<= 1;
936
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
937
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
938
+ a += 32; m <<= 1;
939
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
940
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
941
+ a += 32; m <<= 1;
942
+ q3 += 32;
943
+ }
944
+ a = aux8;
945
+
946
+ memcpy(auxs, x[i].scales, 12);
947
+ uint32_t tmp = auxs[2];
948
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
949
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
950
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
951
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
952
+ for (int j = 0; j < QK_K/16; ++j) {
953
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
954
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
955
+ q8 += 8; a += 8;
956
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
957
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
958
+ q8 += 8; a += 8;
959
+ }
960
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
961
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
962
+ }
963
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
964
+ *s = sumf;
965
+
966
+ #endif
967
+
968
+ }
969
+
970
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
971
+ assert(n % QK_K == 0);
972
+ assert(nrc == 1);
973
+ UNUSED(nrc);
974
+ UNUSED(bx);
975
+ UNUSED(by);
976
+ UNUSED(bs);
977
+
978
+ const block_q4_K * GGML_RESTRICT x = vx;
979
+ const block_q8_K * GGML_RESTRICT y = vy;
980
+
981
+ const int nb = n / QK_K;
982
+
983
+ static const uint32_t kmask1 = 0x3f3f3f3f;
984
+ static const uint32_t kmask2 = 0x0f0f0f0f;
985
+ static const uint32_t kmask3 = 0x03030303;
986
+
987
+ uint32_t utmp[4];
988
+
989
+ #if defined __wasm_simd128__
990
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
991
+ float sumf = 0;
992
+
993
+ for (int i = 0; i < nb; ++i) {
994
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
995
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
996
+
997
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
998
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
999
+
1000
+ // Process scales and mins
1001
+ memcpy(utmp, x[i].scales, 12);
1002
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1003
+ const uint32_t uaux = utmp[1] & kmask1;
1004
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1005
+ utmp[2] = uaux;
1006
+ utmp[0] &= kmask1;
1007
+
1008
+ // Sum mins * q8sums
1009
+ int32_t sumi = 0;
1010
+ const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
1011
+ const uint8_t * m = (const uint8_t *)&utmp[2];
1012
+ for (int j = 0; j < 16; j += 2) {
1013
+ sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
1014
+ }
1015
+ sumf -= dmin * sumi;
1016
+
1017
+ int32_t sumi1 = 0;
1018
+ int32_t sumi2 = 0;
1019
+
1020
+ for (int j = 0; j < QK_K/64; ++j) {
1021
+ // Load 64 4-bit weights (32 bytes)
1022
+ const v128_t q4x0 = wasm_v128_load(q4);
1023
+ const v128_t q4x1 = wasm_v128_load(q4 + 16);
1024
+ q4 += 32;
1025
+
1026
+ // Split into low/high nibbles
1027
+ const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
1028
+ const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
1029
+ const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
1030
+ const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
1031
+
1032
+ // Load 64 8-bit values (64 bytes)
1033
+ const v128_t q8x0 = wasm_v128_load(q8);
1034
+ const v128_t q8x1 = wasm_v128_load(q8 + 16);
1035
+ const v128_t q8x2 = wasm_v128_load(q8 + 32);
1036
+ const v128_t q8x3 = wasm_v128_load(q8 + 48);
1037
+ q8 += 64;
1038
+
1039
+ // Low nibble products
1040
+ v128_t vacc1 = wasm_i32x4_dot_i16x8(
1041
+ wasm_i16x8_extend_low_i8x16(q4l0),
1042
+ wasm_i16x8_extend_low_i8x16(q8x0)
1043
+ );
1044
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
1045
+ wasm_i16x8_extend_high_i8x16(q4l0),
1046
+ wasm_i16x8_extend_high_i8x16(q8x0)
1047
+ ));
1048
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
1049
+ wasm_i16x8_extend_low_i8x16(q4l1),
1050
+ wasm_i16x8_extend_low_i8x16(q8x1)
1051
+ ));
1052
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
1053
+ wasm_i16x8_extend_high_i8x16(q4l1),
1054
+ wasm_i16x8_extend_high_i8x16(q8x1)
1055
+ ));
1056
+
1057
+ // High nibble products
1058
+ v128_t vacc2 = wasm_i32x4_dot_i16x8(
1059
+ wasm_i16x8_extend_low_i8x16(q4h0),
1060
+ wasm_i16x8_extend_low_i8x16(q8x2)
1061
+ );
1062
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
1063
+ wasm_i16x8_extend_high_i8x16(q4h0),
1064
+ wasm_i16x8_extend_high_i8x16(q8x2)
1065
+ ));
1066
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
1067
+ wasm_i16x8_extend_low_i8x16(q4h1),
1068
+ wasm_i16x8_extend_low_i8x16(q8x3)
1069
+ ));
1070
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
1071
+ wasm_i16x8_extend_high_i8x16(q4h1),
1072
+ wasm_i16x8_extend_high_i8x16(q8x3)
1073
+ ));
1074
+
1075
+ // Accumulate scaled results
1076
+ int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
1077
+ wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
1078
+ sumi1 += vacc1_sum * scales[2*j];
1079
+
1080
+ int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
1081
+ wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
1082
+ sumi2 += vacc2_sum * scales[2*j+1];
1083
+ }
1084
+
1085
+ sumf += d * (sumi1 + sumi2);
1086
+ }
1087
+
1088
+ *s = sumf;
1089
+
1090
+ #else
1091
+
1092
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1093
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1094
+
1095
+ int8_t aux8[QK_K];
1096
+ int16_t aux16[8];
1097
+ float sums [8];
1098
+ int32_t aux32[8];
1099
+ memset(sums, 0, 8*sizeof(float));
1100
+
1101
+ float sumf = 0;
1102
+ for (int i = 0; i < nb; ++i) {
1103
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1104
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1105
+ memset(aux32, 0, 8*sizeof(int32_t));
1106
+ int8_t * GGML_RESTRICT a = aux8;
1107
+ for (int j = 0; j < QK_K/64; ++j) {
1108
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1109
+ a += 32;
1110
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1111
+ a += 32; q4 += 32;
1112
+ }
1113
+ memcpy(utmp, x[i].scales, 12);
1114
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1115
+ const uint32_t uaux = utmp[1] & kmask1;
1116
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1117
+ utmp[2] = uaux;
1118
+ utmp[0] &= kmask1;
1119
+
1120
+ int sumi = 0;
1121
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1122
+ a = aux8;
1123
+ int is = 0;
1124
+ for (int j = 0; j < QK_K/32; ++j) {
1125
+ int32_t scale = scales[is++];
1126
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1127
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1128
+ q8 += 8; a += 8;
1129
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1130
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1131
+ q8 += 8; a += 8;
1132
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1133
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1134
+ q8 += 8; a += 8;
1135
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1136
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1137
+ q8 += 8; a += 8;
1138
+ }
1139
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1140
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1141
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1142
+ sumf -= dmin * sumi;
1143
+ }
1144
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1145
+ *s = sumf;
1146
+ #endif
1147
+ }
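Both the WASM and scalar branches above reshuffle the 12 packed scale bytes through utmp with kmask1/kmask2/kmask3 before reading scales from &utmp[0] and mins from &utmp[2]. The bit twiddling is a batched unpack of eight 6-bit scales and eight 6-bit mins; a minimal scalar sketch of the same unpacking, written for this note under the assumption of the standard K-quant 6-bit packing (not code from the commit), is:

    // d = 6-bit scale, m = 6-bit min for sub-block j (0..7); q points at the 12 packed bytes (x[i].scales)
    static void unpack_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
        }
    }

After the in-place shuffle, utmp[0..1] hold the eight scales and utmp[2..3] the eight mins, which is what the per-sub-block loops index into.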
1148
+
1149
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1150
+ assert(n % QK_K == 0);
1151
+ assert(nrc == 1);
1152
+ UNUSED(nrc);
1153
+ UNUSED(bx);
1154
+ UNUSED(by);
1155
+ UNUSED(bs);
1156
+
1157
+ const block_q5_K * GGML_RESTRICT x = vx;
1158
+ const block_q8_K * GGML_RESTRICT y = vy;
1159
+
1160
+ const int nb = n / QK_K;
1161
+
1162
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1163
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1164
+ static const uint32_t kmask3 = 0x03030303;
1165
+
1166
+ uint32_t utmp[4];
1167
+
1168
+ #if defined __wasm_simd128__
1169
+ //const uint8_t * scales = (const uint8_t*)&utmp[0];
1170
+ float sumf = 0;
1171
+
1172
+ for (int i = 0; i < nb; ++i) {
1173
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1174
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
1175
+
1176
+ const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1177
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1178
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1179
+
1180
+ // Process scales and mins
1181
+ memcpy(utmp, x[i].scales, 12);
1182
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1183
+ const uint32_t uaux = utmp[1] & kmask1;
1184
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1185
+ utmp[2] = uaux;
1186
+ utmp[0] &= kmask1;
1187
+
1188
+ // Sum mins * q8sums
1189
+ int32_t sumi_mins = 0;
1190
+ const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
1191
+ const uint8_t * m = (const uint8_t *)&utmp[2];
1192
+ for (int j = 0; j < 16; j += 2) {
1193
+ sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
1194
+ }
1195
+ sumf -= dmin * sumi_mins; // Correct subtraction
1196
+
1197
+ v128_t qh0 = wasm_v128_load(qh);
1198
+ v128_t qh1 = wasm_v128_load(qh + 16);
1199
+ const uint8_t * sc = (const uint8_t *)utmp;
1200
+
1201
+ int32_t sumi = 0;
1202
+
1203
+ for (int j = 0; j < QK_K/64; ++j) {
1204
+ const int shift = j * 2;
1205
+ v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
1206
+ v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
1207
+
1208
+ v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
1209
+ v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
1210
+ v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
1211
+ v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
1212
+
1213
+ v128_t q5_0 = wasm_v128_load(q5);
1214
+ v128_t q5_1 = wasm_v128_load(q5 + 16);
1215
+ q5 += 32;
1216
+
1217
+ v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
1218
+ v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
1219
+ v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
1220
+ v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
1221
+
1222
+ v128_t q8_0 = wasm_v128_load(q8);
1223
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
1224
+ v128_t q8_2 = wasm_v128_load(q8 + 32);
1225
+ v128_t q8_3 = wasm_v128_load(q8 + 48);
1226
+ q8 += 64;
1227
+
1228
+ // Process low quants
1229
+ v128_t pl0 = wasm_i32x4_dot_i16x8(
1230
+ wasm_i16x8_extend_low_i8x16(q5l_0),
1231
+ wasm_i16x8_extend_low_i8x16(q8_0)
1232
+ );
1233
+ pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
1234
+ wasm_i16x8_extend_high_i8x16(q5l_0),
1235
+ wasm_i16x8_extend_high_i8x16(q8_0)
1236
+ ));
1237
+ v128_t pl1 = wasm_i32x4_dot_i16x8(
1238
+ wasm_i16x8_extend_low_i8x16(q5l_1),
1239
+ wasm_i16x8_extend_low_i8x16(q8_1)
1240
+ );
1241
+ pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
1242
+ wasm_i16x8_extend_high_i8x16(q5l_1),
1243
+ wasm_i16x8_extend_high_i8x16(q8_1)
1244
+ ));
1245
+ v128_t sum_low = wasm_i32x4_add(pl0, pl1);
1246
+
1247
+ // Process high quants
1248
+ v128_t ph0 = wasm_i32x4_dot_i16x8(
1249
+ wasm_i16x8_extend_low_i8x16(q5h_0),
1250
+ wasm_i16x8_extend_low_i8x16(q8_2)
1251
+ );
1252
+ ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
1253
+ wasm_i16x8_extend_high_i8x16(q5h_0),
1254
+ wasm_i16x8_extend_high_i8x16(q8_2)
1255
+ ));
1256
+ v128_t ph1 = wasm_i32x4_dot_i16x8(
1257
+ wasm_i16x8_extend_low_i8x16(q5h_1),
1258
+ wasm_i16x8_extend_low_i8x16(q8_3)
1259
+ );
1260
+ ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
1261
+ wasm_i16x8_extend_high_i8x16(q5h_1),
1262
+ wasm_i16x8_extend_high_i8x16(q8_3)
1263
+ ));
1264
+ v128_t sum_high = wasm_i32x4_add(ph0, ph1);
1265
+
1266
+ // Accumulate with scale factors
1267
+ int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
1268
+ wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
1269
+ int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
1270
+ wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
1271
+
1272
+ sumi += sl * sc[2*j] + sh * sc[2*j+1];
1273
+ }
1274
+
1275
+ sumf += d * sumi;
1276
+ }
1277
+
1278
+ *s = sumf;
1279
+
1280
+ #else
1281
+
1282
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1283
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1284
+
1285
+ int8_t aux8[QK_K];
1286
+ int16_t aux16[8];
1287
+ float sums [8];
1288
+ int32_t aux32[8];
1289
+ memset(sums, 0, 8*sizeof(float));
1290
+
1291
+ float sumf = 0;
1292
+ for (int i = 0; i < nb; ++i) {
1293
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1294
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1295
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1296
+ memset(aux32, 0, 8*sizeof(int32_t));
1297
+ int8_t * GGML_RESTRICT a = aux8;
1298
+ uint8_t m = 1;
1299
+ for (int j = 0; j < QK_K/64; ++j) {
1300
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1301
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1302
+ a += 32; m <<= 1;
1303
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1304
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1305
+ a += 32; m <<= 1;
1306
+ q4 += 32;
1307
+ }
1308
+ memcpy(utmp, x[i].scales, 12);
1309
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1310
+ const uint32_t uaux = utmp[1] & kmask1;
1311
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1312
+ utmp[2] = uaux;
1313
+ utmp[0] &= kmask1;
1314
+
1315
+ int sumi = 0;
1316
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1317
+ a = aux8;
1318
+ int is = 0;
1319
+ for (int j = 0; j < QK_K/32; ++j) {
1320
+ int32_t scale = scales[is++];
1321
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1322
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1323
+ q8 += 8; a += 8;
1324
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1325
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1326
+ q8 += 8; a += 8;
1327
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1328
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1329
+ q8 += 8; a += 8;
1330
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1331
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1332
+ q8 += 8; a += 8;
1333
+ }
1334
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1335
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1336
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1337
+ sumf -= dmin * sumi;
1338
+ }
1339
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1340
+ *s = sumf;
1341
+ #endif
1342
+ }
1343
+
1344
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1345
+ assert(n % QK_K == 0);
1346
+ assert(nrc == 1);
1347
+ UNUSED(nrc);
1348
+ UNUSED(bx);
1349
+ UNUSED(by);
1350
+ UNUSED(bs);
1351
+
1352
+ const block_q6_K * GGML_RESTRICT x = vx;
1353
+ const block_q8_K * GGML_RESTRICT y = vy;
1354
+
1355
+ const int nb = n / QK_K;
1356
+
1357
+ #if defined __wasm_simd128__
1358
+ int8_t aux8[QK_K] __attribute__((aligned(16)));
1359
+ int32_t aux32[8] __attribute__((aligned(16))) = {0};
1360
+ float sums[8] __attribute__((aligned(16))) = {0};
1361
+
1362
+ for (int i = 0; i < nb; ++i) {
1363
+ // Unpack 6-bit quantized data into aux8 (unchanged)
1364
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1365
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1366
+ int8_t * a = aux8;
1367
+ for (int j = 0; j < QK_K; j += 128) {
1368
+ for (int l = 0; l < 32; ++l) {
1369
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1370
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1371
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1372
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1373
+ }
1374
+ a += 128;
1375
+ q4 += 64;
1376
+ qh += 32;
1377
+ }
1378
+
1379
+ const int8_t * GGML_RESTRICT a_ptr = aux8;
1380
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1381
+ v128_t acc0 = wasm_i32x4_splat(0);
1382
+ v128_t acc1 = wasm_i32x4_splat(0);
1383
+
1384
+ for (int j = 0; j < QK_K/16; ++j) {
1385
+ const int scale = x[i].scales[j];
1386
+ const v128_t vscale = wasm_i32x4_splat(scale);
1387
+
1388
+ // Load 16 elements from a and q8
1389
+ const v128_t a_vec = wasm_v128_load(a_ptr);
1390
+ const v128_t q8_vec = wasm_v128_load(q8);
1391
+
1392
+ // Process low 8 elements
1393
+ v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
1394
+ v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
1395
+ v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
1396
+ v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
1397
+ v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
1398
+
1399
+ // Process high 8 elements
1400
+ v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
1401
+ v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
1402
+ v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
1403
+ v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
1404
+ v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
1405
+
1406
+ // Scale and accumulate
1407
+ prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
1408
+ prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
1409
+ prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
1410
+ prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
1411
+
1412
+ acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
1413
+ acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
1414
+
1415
+ a_ptr += 16;
1416
+ q8 += 16;
1417
+ }
1418
+
1419
+ // Store accumulated results
1420
+ wasm_v128_store(&aux32[0], acc0);
1421
+ wasm_v128_store(&aux32[4], acc1);
1422
+
1423
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1424
+ for (int l = 0; l < 8; ++l) {
1425
+ sums[l] += d * aux32[l];
1426
+ }
1427
+ }
1428
+
1429
+ // Sum final results
1430
+ float sumf = 0;
1431
+ for (int l = 0; l < 8; ++l) {
1432
+ sumf += sums[l];
1433
+ }
1434
+ *s = sumf;
1435
+
1436
+ #else
1437
+
1438
+ int8_t aux8[QK_K];
1439
+ int16_t aux16[8];
1440
+ float sums [8];
1441
+ int32_t aux32[8];
1442
+ memset(sums, 0, 8*sizeof(float));
1443
+
1444
+ float sumf = 0;
1445
+ for (int i = 0; i < nb; ++i) {
1446
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1447
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1448
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1449
+ memset(aux32, 0, 8*sizeof(int32_t));
1450
+ int8_t * GGML_RESTRICT a = aux8;
1451
+ for (int j = 0; j < QK_K; j += 128) {
1452
+ for (int l = 0; l < 32; ++l) {
1453
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1454
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1455
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1456
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1457
+ }
1458
+ a += 128;
1459
+ q4 += 64;
1460
+ qh += 32;
1461
+ }
1462
+ a = aux8;
1463
+ int is = 0;
1464
+ for (int j = 0; j < QK_K/16; ++j) {
1465
+ int scale = x[i].scales[is++];
1466
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1467
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1468
+ q8 += 8; a += 8;
1469
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1470
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1471
+ q8 += 8; a += 8;
1472
+ }
1473
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1474
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1475
+ }
1476
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1477
+ *s = sumf;
1478
+ #endif
1479
+ }
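Each Q6_K weight is rebuilt above from a low nibble in ql and two high bits in qh, then recentred by 32. Worked by hand with illustrative values (not taken from the commit):

    // ql nibble = 10, qh bits = 3:  (10 | (3 << 4)) - 32 = 58 - 32 =  26
    // ql nibble = 10, qh bits = 0:  (10 | (0 << 4)) - 32 = 10 - 32 = -22
    // so every unpacked value lands in [-32, 31] before being weighted by x[i].scales[j]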
1480
+
ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp ADDED
@@ -0,0 +1,327 @@
1
+ #include "ggml-backend-impl.h"
2
+
3
+ #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
4
+
5
+ #ifdef _MSC_VER
6
+ #include <intrin.h>
7
+ #endif
8
+
9
+ #include <cstring>
10
+ #include <vector>
11
+ #include <bitset>
12
+ #include <array>
13
+ #include <string>
14
+
15
+ // ref: https://cdrdv2-public.intel.com/782156/325383-sdm-vol-2abcd.pdf
16
+ struct cpuid_x86 {
17
+ bool SSE3(void) { return f_1_ecx[0]; }
18
+ bool PCLMULQDQ(void) { return f_1_ecx[1]; }
19
+ bool MONITOR(void) { return f_1_ecx[3]; }
20
+ bool SSSE3(void) { return f_1_ecx[9]; }
21
+ bool FMA(void) { return f_1_ecx[12]; }
22
+ bool CMPXCHG16B(void) { return f_1_ecx[13]; }
23
+ bool SSE41(void) { return f_1_ecx[19]; }
24
+ bool SSE42(void) { return f_1_ecx[20]; }
25
+ bool MOVBE(void) { return f_1_ecx[22]; }
26
+ bool POPCNT(void) { return f_1_ecx[23]; }
27
+ bool AES(void) { return f_1_ecx[25]; }
28
+ bool XSAVE(void) { return f_1_ecx[26]; }
29
+ bool OSXSAVE(void) { return f_1_ecx[27]; }
30
+ bool AVX(void) { return f_1_ecx[28]; }
31
+ bool F16C(void) { return f_1_ecx[29]; }
32
+ bool RDRAND(void) { return f_1_ecx[30]; }
33
+
34
+ bool MSR(void) { return f_1_edx[5]; }
35
+ bool CX8(void) { return f_1_edx[8]; }
36
+ bool SEP(void) { return f_1_edx[11]; }
37
+ bool CMOV(void) { return f_1_edx[15]; }
38
+ bool CLFSH(void) { return f_1_edx[19]; }
39
+ bool MMX(void) { return f_1_edx[23]; }
40
+ bool FXSR(void) { return f_1_edx[24]; }
41
+ bool SSE(void) { return f_1_edx[25]; }
42
+ bool SSE2(void) { return f_1_edx[26]; }
43
+
44
+ bool FSGSBASE(void) { return f_7_ebx[0]; }
45
+ bool BMI1(void) { return f_7_ebx[3]; }
46
+ bool HLE(void) { return is_intel && f_7_ebx[4]; }
47
+ bool AVX2(void) { return f_7_ebx[5]; }
48
+ bool BMI2(void) { return f_7_ebx[8]; }
49
+ bool ERMS(void) { return f_7_ebx[9]; }
50
+ bool INVPCID(void) { return f_7_ebx[10]; }
51
+ bool RTM(void) { return is_intel && f_7_ebx[11]; }
52
+ bool AVX512F(void) { return f_7_ebx[16]; }
53
+ bool AVX512DQ(void) { return f_7_ebx[17]; }
54
+ bool RDSEED(void) { return f_7_ebx[18]; }
55
+ bool ADX(void) { return f_7_ebx[19]; }
56
+ bool AVX512PF(void) { return f_7_ebx[26]; }
57
+ bool AVX512ER(void) { return f_7_ebx[27]; }
58
+ bool AVX512CD(void) { return f_7_ebx[28]; }
59
+ bool AVX512BW(void) { return f_7_ebx[30]; }
60
+ bool AVX512VL(void) { return f_7_ebx[31]; }
61
+
62
+ bool SHA(void) { return f_7_ebx[29]; }
63
+
64
+ bool PREFETCHWT1(void) { return f_7_ecx[0]; }
65
+
66
+ bool LAHF(void) { return f_81_ecx[0]; }
67
+ bool LZCNT(void) { return is_intel && f_81_ecx[5]; }
68
+ bool ABM(void) { return is_amd && f_81_ecx[5]; }
69
+ bool SSE4a(void) { return is_amd && f_81_ecx[6]; }
70
+ bool XOP(void) { return is_amd && f_81_ecx[11]; }
71
+ bool TBM(void) { return is_amd && f_81_ecx[21]; }
72
+
73
+ bool SYSCALL(void) { return is_intel && f_81_edx[11]; }
74
+ bool MMXEXT(void) { return is_amd && f_81_edx[22]; }
75
+ bool RDTSCP(void) { return is_intel && f_81_edx[27]; }
76
+ bool _3DNOWEXT(void) { return is_amd && f_81_edx[30]; }
77
+ bool _3DNOW(void) { return is_amd && f_81_edx[31]; }
78
+
79
+ bool AVX512_VBMI(void) { return f_7_ecx[1]; }
80
+ bool AVX512_VNNI(void) { return f_7_ecx[11]; }
81
+ bool AVX512_FP16(void) { return f_7_edx[23]; }
82
+ bool AVX512_BF16(void) { return f_7_1_eax[5]; }
83
+ bool AVX_VNNI(void) { return f_7_1_eax[4]; }
84
+
85
+ bool AMX_TILE(void) { return f_7_edx[24]; }
86
+ bool AMX_INT8(void) { return f_7_edx[25]; }
87
+ bool AMX_FP16(void) { return f_7_1_eax[21]; }
88
+ bool AMX_BF16(void) { return f_7_edx[22]; }
89
+
90
+ #ifdef _MSC_VER
91
+ static void cpuid(int cpu_info[4], int eax) {
92
+ __cpuid(cpu_info, eax);
93
+ }
94
+ static void cpuidex(int cpu_info[4], int eax, int ecx) {
95
+ __cpuidex(cpu_info, eax, ecx);
96
+ }
97
+ #else
98
+ static void cpuid(int cpu_info[4], int eax) {
99
+ __asm__ __volatile__(
100
+ "cpuid"
101
+ : "=a"(cpu_info[0]), "=b"(cpu_info[1]), "=c"(cpu_info[2]), "=d"(cpu_info[3])
102
+ : "a"(eax), "c"(0));
103
+ }
104
+ static void cpuidex(int cpu_info[4], int eax, int ecx) {
105
+ __asm__ __volatile__(
106
+ "cpuid"
107
+ : "=a"(cpu_info[0]), "=b"(cpu_info[1]), "=c"(cpu_info[2]), "=d"(cpu_info[3])
108
+ : "a"(eax), "c"(ecx));
109
+ }
110
+ #endif
111
+
112
+ cpuid_x86() {
113
+ std::array<int, 4> cpui;
114
+ std::vector<std::array<int, 4>> data;
115
+
116
+ // calling __cpuid with 0x0 as the function_id argument
117
+ // gets the number of the highest valid function ID.
118
+ cpuid(cpui.data(), 0);
119
+ int n_ids = cpui[0];
120
+
121
+ for (int i = 0; i <= n_ids; ++i) {
122
+ cpuidex(cpui.data(), i, 0);
123
+ data.push_back(cpui);
124
+ }
125
+
126
+ // capture vendor string
127
+ char vendor[0x20] = {};
128
+ *reinterpret_cast<int *>(vendor) = data[0][1];
129
+ *reinterpret_cast<int *>(vendor + 4) = data[0][3];
130
+ *reinterpret_cast<int *>(vendor + 8) = data[0][2];
131
+ this->vendor = vendor;
132
+ if (this->vendor == "GenuineIntel") {
133
+ is_intel = true;
134
+ } else if (this->vendor == "AuthenticAMD") {
135
+ is_amd = true;
136
+ }
137
+
138
+ // load bitset with flags for function 0x00000001
139
+ if (n_ids >= 1) {
140
+ f_1_ecx = data[1][2];
141
+ f_1_edx = data[1][3];
142
+ }
143
+
144
+ // load bitset with flags for function 0x00000007
145
+ if (n_ids >= 7) {
146
+ f_7_ebx = data[7][1];
147
+ f_7_ecx = data[7][2];
148
+ f_7_edx = data[7][3];
149
+ cpuidex(cpui.data(), 7, 1);
150
+ f_7_1_eax = cpui[0];
151
+ }
152
+
153
+ // calling __cpuid with 0x80000000 as the function_id argument
154
+ // gets the number of the highest valid extended ID.
155
+ cpuid(cpui.data(), 0x80000000);
156
+ unsigned int n_ex_ids = cpui[0];
157
+
158
+ std::vector<std::array<int, 4>> ext_data;
159
+ for (unsigned int i = 0x80000000; i <= n_ex_ids; ++i) {
160
+ cpuidex(cpui.data(), i, 0);
161
+ ext_data.push_back(cpui);
162
+ }
163
+
164
+ // load bitset with flags for function 0x80000001
165
+ if (n_ex_ids >= 0x80000001) {
166
+ f_81_ecx = ext_data[1][2];
167
+ f_81_edx = ext_data[1][3];
168
+ }
169
+
170
+ // interpret CPU brand string if reported
171
+ char brand[0x40] = {};
172
+ if (n_ex_ids >= 0x80000004) {
173
+ std::memcpy(brand, ext_data[2].data(), sizeof(cpui));
174
+ std::memcpy(brand + 16, ext_data[3].data(), sizeof(cpui));
175
+ std::memcpy(brand + 32, ext_data[4].data(), sizeof(cpui));
176
+ this->brand = brand;
177
+ }
178
+ }
179
+
180
+ bool is_intel = false;
181
+ bool is_amd = false;
182
+ std::string vendor;
183
+ std::string brand;
184
+ std::bitset<32> f_1_ecx;
185
+ std::bitset<32> f_1_edx;
186
+ std::bitset<32> f_7_ebx;
187
+ std::bitset<32> f_7_ecx;
188
+ std::bitset<32> f_7_edx;
189
+ std::bitset<32> f_7_1_eax;
190
+ std::bitset<32> f_81_ecx;
191
+ std::bitset<32> f_81_edx;
192
+ };
193
+
194
+ #if 0
195
+ void test_x86_is() {
196
+ cpuid_x86 is;
197
+ printf("CPU Vendor: %s\n", is.vendor.c_str());
198
+ printf("Brand: %s\n", is.brand.c_str());
199
+ printf("is_intel: %d\n", is.is_intel);
200
+ printf("is_amd: %d\n", is.is_amd);
201
+ printf("sse3: %d\n", is.SSE3());
202
+ printf("pclmulqdq: %d\n", is.PCLMULQDQ());
203
+ printf("ssse3: %d\n", is.SSSE3());
204
+ printf("fma: %d\n", is.FMA());
205
+ printf("cmpxchg16b: %d\n", is.CMPXCHG16B());
206
+ printf("sse41: %d\n", is.SSE41());
207
+ printf("sse42: %d\n", is.SSE42());
208
+ printf("movbe: %d\n", is.MOVBE());
209
+ printf("popcnt: %d\n", is.POPCNT());
210
+ printf("aes: %d\n", is.AES());
211
+ printf("xsave: %d\n", is.XSAVE());
212
+ printf("osxsave: %d\n", is.OSXSAVE());
213
+ printf("avx: %d\n", is.AVX());
214
+ printf("f16c: %d\n", is.F16C());
215
+ printf("rdrand: %d\n", is.RDRAND());
216
+ printf("msr: %d\n", is.MSR());
217
+ printf("cx8: %d\n", is.CX8());
218
+ printf("sep: %d\n", is.SEP());
219
+ printf("cmov: %d\n", is.CMOV());
220
+ printf("clflush: %d\n", is.CLFSH());
221
+ printf("mmx: %d\n", is.MMX());
222
+ printf("fxsr: %d\n", is.FXSR());
223
+ printf("sse: %d\n", is.SSE());
224
+ printf("sse2: %d\n", is.SSE2());
225
+ printf("fsgsbase: %d\n", is.FSGSBASE());
226
+ printf("bmi1: %d\n", is.BMI1());
227
+ printf("hle: %d\n", is.HLE());
228
+ printf("avx2: %d\n", is.AVX2());
229
+ printf("bmi2: %d\n", is.BMI2());
230
+ printf("erms: %d\n", is.ERMS());
231
+ printf("invpcid: %d\n", is.INVPCID());
232
+ printf("rtm: %d\n", is.RTM());
233
+ printf("avx512f: %d\n", is.AVX512F());
234
+ printf("rdseed: %d\n", is.RDSEED());
235
+ printf("adx: %d\n", is.ADX());
236
+ printf("avx512pf: %d\n", is.AVX512PF());
237
+ printf("avx512er: %d\n", is.AVX512ER());
238
+ printf("avx512cd: %d\n", is.AVX512CD());
239
+ printf("sha: %d\n", is.SHA());
240
+ printf("prefetchwt1: %d\n", is.PREFETCHWT1());
241
+ printf("lahf: %d\n", is.LAHF());
242
+ printf("lzcnt: %d\n", is.LZCNT());
243
+ printf("abm: %d\n", is.ABM());
244
+ printf("sse4a: %d\n", is.SSE4a());
245
+ printf("xop: %d\n", is.XOP());
246
+ printf("tbm: %d\n", is.TBM());
247
+ printf("syscall: %d\n", is.SYSCALL());
248
+ printf("mmxext: %d\n", is.MMXEXT());
249
+ printf("rdtscp: %d\n", is.RDTSCP());
250
+ printf("3dnowext: %d\n", is._3DNOWEXT());
251
+ printf("3dnow: %d\n", is._3DNOW());
252
+ printf("avx512_vbmi: %d\n", is.AVX512_VBMI());
253
+ printf("avx512_vnni: %d\n", is.AVX512_VNNI());
254
+ printf("avx512_fp16: %d\n", is.AVX512_FP16());
255
+ printf("avx512_bf16: %d\n", is.AVX512_BF16());
256
+ printf("amx_tile: %d\n", is.AMX_TILE());
257
+ printf("amx_int8: %d\n", is.AMX_INT8());
258
+ printf("amx_fp16: %d\n", is.AMX_FP16());
259
+ printf("amx_bf16: %d\n", is.AMX_BF16());
260
+ }
261
+ #endif
262
+
263
+ static int ggml_backend_cpu_x86_score() {
264
+ // FIXME: this does not check for OS support
265
+
266
+ int score = 1;
267
+ cpuid_x86 is;
268
+
269
+ #ifdef GGML_FMA
270
+ if (!is.FMA()) { return 0; }
271
+ score += 1;
272
+ #endif
273
+ #ifdef GGML_F16C
274
+ if (!is.F16C()) { return 0; }
275
+ score += 1<<1;
276
+ #endif
277
+ #ifdef GGML_SSE42
278
+ if (!is.SSE42()) { return 0; }
279
+ score += 1<<2;
280
+ #endif
281
+ #ifdef GGML_BMI2
282
+ if (!is.BMI2()) { return 0; }
283
+ score += 1<<3;
284
+ #endif
285
+ #ifdef GGML_AVX
286
+ if (!is.AVX()) { return 0; }
287
+ score += 1<<4;
288
+ #endif
289
+ #ifdef GGML_AVX2
290
+ if (!is.AVX2()) { return 0; }
291
+ score += 1<<5;
292
+ #endif
293
+ #ifdef GGML_AVX_VNNI
294
+ if (!is.AVX_VNNI()) { return 0; }
295
+ score += 1<<6;
296
+ #endif
297
+ #ifdef GGML_AVX512
298
+ if (!is.AVX512F()) { return 0; }
299
+ if (!is.AVX512CD()) { return 0; }
300
+ if (!is.AVX512VL()) { return 0; }
301
+ if (!is.AVX512DQ()) { return 0; }
302
+ if (!is.AVX512BW()) { return 0; }
303
+ score += 1<<7;
304
+ #endif
305
+ #ifdef GGML_AVX512_VBMI
306
+ if (!is.AVX512_VBMI()) { return 0; }
307
+ score += 1<<8;
308
+ #endif
309
+ #ifdef GGML_AVX512_BF16
310
+ if (!is.AVX512_BF16()) { return 0; }
311
+ score += 1<<9;
312
+ #endif
313
+ #ifdef GGML_AVX512_VNNI
314
+ if (!is.AVX512_VNNI()) { return 0; }
315
+ score += 1<<10;
316
+ #endif
317
+ #ifdef GGML_AMX_INT8
318
+ if (!is.AMX_INT8()) { return 0; }
319
+ score += 1<<11;
320
+ #endif
321
+
322
+ return score;
323
+ }
324
+
325
+ GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_x86_score)
326
+
327
+ #endif // defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
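The score above is a bit-weighted ranking of build variants: every ISA extension the variant was compiled for must be reported by cpuid (otherwise the variant scores 0), and more capable extensions add larger powers of two so the most capable runnable variant wins. A worked example, illustrative rather than taken from the commit:

    // variant compiled with GGML_SSE42 + GGML_AVX + GGML_AVX2 + GGML_FMA + GGML_F16C + GGML_BMI2,
    // running on a CPU that has all of these but no AVX-512:
    //   score = 1 + (1<<0) + (1<<1) + (1<<2) + (1<<3) + (1<<4) + (1<<5) = 64
    // a variant that additionally requires GGML_AVX512, on the same CPU:
    //   is.AVX512F() == false  ->  score = 0, so a loader comparing scores skips it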
ggml/src/ggml-cpu/arch/x86/quants.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/arch/x86/repack.cpp ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/common.h CHANGED
@@ -1,7 +1,7 @@
1
  #pragma once
2
 
3
  #include "ggml.h"
4
- #include "ggml-cpu-traits.h"
5
  #include "ggml-cpu-impl.h"
6
  #include "ggml-impl.h"
7
 
 
1
  #pragma once
2
 
3
  #include "ggml.h"
4
+ #include "traits.h"
5
  #include "ggml-cpu-impl.h"
6
  #include "ggml-impl.h"
7
 
ggml/src/ggml-cpu/ggml-cpu-impl.h CHANGED
@@ -506,3 +506,25 @@ void ggml_barrier(struct ggml_threadpool * tp);
506
  #ifdef __cplusplus
507
  }
508
  #endif
 
506
  #ifdef __cplusplus
507
  }
508
  #endif
509
+
510
+ #define GGML_DO_PRAGMA_(x) _Pragma (#x)
511
+ #define GGML_DO_PRAGMA(x) GGML_DO_PRAGMA_(x)
512
+ #if defined(GGML_CPU_GENERIC) || defined(__HIPCC__)
513
+ // Note for Apple targets:
514
+ // - clang: aliases are not supported on darwin
515
+ // - all native kernels need to be implemented in both x86 and arm files
516
+ // - on iOS, tvOS, and visionOS, if cmake cannot determine the target architecture, all `_generic` names are replaced by defines
517
+ # define GGML_WEAK_ALIAS(name, alias)
518
+ #elif defined(__GNUC__)
519
+ // GCC/Clang on *nix
520
+ # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT
521
+ #elif defined(_MSC_VER) && defined (_WIN64)
522
+ // MSVC
523
+ // Note: C name mangling varies across different calling conventions
524
+ // see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170
525
+ # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias))
526
+ #else
527
+ # error "Unsupported compiler for GGML_WEAK_ALIAS"
528
+ #endif
529
+
530
+ #define GGML_CPU_NATIVE_IMPL(name) GGML_WEAK_ALIAS(name, name ## _generic)
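GGML_CPU_NATIVE_IMPL lets the portable implementation be compiled under a _generic suffix and weak-aliased to the public name, so an arch-specific file can override it simply by providing a strong definition. A sketch of the pattern with an illustrative function name (the real dot-product signatures also carry the bs/bx/by/nrc arguments):

    // in the generic translation unit (always built), e.g. ggml-cpu/quants.c:
    void ggml_vec_dot_example_generic(int n, float * s, const void * vx, const void * vy) {
        // portable reference implementation
    }
    GGML_CPU_NATIVE_IMPL(ggml_vec_dot_example)
    // on GCC/Clang this expands to: _Pragma("weak ggml_vec_dot_example = ggml_vec_dot_example_generic")

    // in an arch translation unit (built only for native targets), e.g. ggml-cpu/arch/x86/quants.c:
    void ggml_vec_dot_example(int n, float * s, const void * vx, const void * vy) {
        // SIMD implementation; as a strong definition it takes precedence over the weak alias
    }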
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -3,11 +3,11 @@
3
 
4
  #include "ggml-backend-impl.h"
5
  #include "ggml-backend.h"
6
- #include "ggml-cpu-traits.h"
7
  #include "ggml-cpu-impl.h"
8
  #include "ggml-cpu.h"
9
  #include "ggml-impl.h"
10
- #include "ggml-cpu-quants.h"
11
  #include "ggml-threading.h"
12
  #include "unary-ops.h"
13
  #include "binary-ops.h"
 
3
 
4
  #include "ggml-backend-impl.h"
5
  #include "ggml-backend.h"
6
+ #include "traits.h"
7
  #include "ggml-cpu-impl.h"
8
  #include "ggml-cpu.h"
9
  #include "ggml-impl.h"
10
+ #include "quants.h"
11
  #include "ggml-threading.h"
12
  #include "unary-ops.h"
13
  #include "binary-ops.h"
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -1,8 +1,8 @@
1
  #include "ggml-backend.h"
2
  #include "ggml-backend-impl.h"
3
  #include "ggml-cpu.h"
4
- #include "ggml-cpu-aarch64.h"
5
- #include "ggml-cpu-traits.h"
6
  #include "ggml-impl.h"
7
  #include "amx/amx.h"
8
 
@@ -11,7 +11,7 @@
11
  #include <vector>
12
 
13
  #ifdef GGML_USE_CPU_HBM
14
- # include "ggml-cpu-hbm.h"
15
  #endif
16
 
17
  #ifdef GGML_USE_CPU_KLEIDIAI
@@ -51,9 +51,9 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
51
  }
52
  #endif
53
 
54
- #ifdef GGML_USE_CPU_AARCH64
55
- if (ggml_backend_cpu_aarch64_buffer_type()) {
56
- bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
57
  }
58
  #endif
59
 
@@ -596,8 +596,8 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
596
  #ifdef GGML_USE_CPU_KLEIDIAI
597
  features.push_back({ "KLEIDIAI", "1" });
598
  #endif
599
- #ifdef GGML_USE_CPU_AARCH64
600
- features.push_back({ "AARCH64_REPACK", "1" });
601
  #endif
602
 
603
  features.push_back({ nullptr, nullptr });
 
1
  #include "ggml-backend.h"
2
  #include "ggml-backend-impl.h"
3
  #include "ggml-cpu.h"
4
+ #include "repack.h"
5
+ #include "traits.h"
6
  #include "ggml-impl.h"
7
  #include "amx/amx.h"
8
 
 
11
  #include <vector>
12
 
13
  #ifdef GGML_USE_CPU_HBM
14
+ # include "hbm.h"
15
  #endif
16
 
17
  #ifdef GGML_USE_CPU_KLEIDIAI
 
51
  }
52
  #endif
53
 
54
+ #ifdef GGML_USE_CPU_REPACK
55
+ if (ggml_backend_cpu_repack_buffer_type()) {
56
+ bufts.push_back(ggml_backend_cpu_repack_buffer_type());
57
  }
58
  #endif
59
 
 
596
  #ifdef GGML_USE_CPU_KLEIDIAI
597
  features.push_back({ "KLEIDIAI", "1" });
598
  #endif
599
+ #ifdef GGML_USE_CPU_REPACK
600
+ features.push_back({ "REPACK", "1" });
601
  #endif
602
 
603
  features.push_back({ nullptr, nullptr });
ggml/src/ggml-cpu/hbm.cpp ADDED
@@ -0,0 +1,55 @@
1
+ #ifdef GGML_USE_CPU_HBM
2
+
3
+ #include "ggml-backend.h"
4
+ #include "ggml-backend-impl.h"
5
+ #include "ggml-cpu.h"
6
+ #include "ggml-impl.h"
7
+
8
+ #include "hbm.h"
9
+
10
+ // buffer type HBM
11
+
12
+ #include <hbwmalloc.h>
13
+
14
+ static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
15
+ return "CPU_HBM";
16
+
17
+ GGML_UNUSED(buft);
18
+ }
19
+
20
+ static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
21
+ hbw_free(buffer->context);
22
+ }
23
+
24
+ static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
25
+ size_t size) {
26
+ void * ptr;
27
+ int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
28
+ if (result != 0) {
29
+ GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
30
+ return NULL;
31
+ }
32
+
33
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
34
+ buffer->buft = buft;
35
+ buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
36
+
37
+ return buffer;
38
+ }
39
+
40
+ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
41
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
42
+ /* .iface = */ {
43
+ /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
44
+ /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
45
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
46
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
47
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
48
+ /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
49
+ },
50
+ /* .context = */ nullptr,
51
+ };
52
+
53
+ return &ggml_backend_cpu_buffer_type_hbm;
54
+ }
55
+ #endif
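For context, the buffer type above is consumed through the regular ggml-backend buffer API rather than called directly. A minimal sketch, assuming a build with GGML_USE_CPU_HBM and the standard ggml-backend entry points:

    #include "ggml-backend.h"
    #include "hbm.h"

    // allocate a scratch buffer in high-bandwidth memory instead of regular CPU memory
    static ggml_backend_buffer_t alloc_hbm_scratch(size_t size) {
        ggml_backend_buffer_type_t buft = ggml_backend_cpu_hbm_buffer_type();
        return ggml_backend_buft_alloc_buffer(buft, size); // release with ggml_backend_buffer_free()
    }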
ggml/src/ggml-cpu/hbm.h ADDED
@@ -0,0 +1,8 @@
1
+ #pragma once
2
+
3
+ #include "ggml-backend.h"
4
+ #include "ggml.h"
5
+
6
+ // GGML CPU internal header
7
+
8
+ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp CHANGED
@@ -26,7 +26,7 @@
26
  #include "ggml-impl.h"
27
  #include "ggml-backend-impl.h"
28
  #include "ggml-threading.h"
29
- #include "ggml-cpu-traits.h"
30
 
31
  #include "kernels.h"
32
 
 
26
  #include "ggml-impl.h"
27
  #include "ggml-backend-impl.h"
28
  #include "ggml-threading.h"
29
+ #include "traits.h"
30
 
31
  #include "kernels.h"
32
 
ggml/src/ggml-cpu/quants.c ADDED
@@ -0,0 +1,1179 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+
4
+ #include "ggml-cpu-impl.h"
5
+ #include "ggml-quants.h"
6
+ #include "quants.h"
7
+
8
+ #include <string.h>
9
+ #include <assert.h>
10
+ #include <float.h>
11
+ #include <stdlib.h> // for qsort
12
+ #include <stdio.h> // for GGML_ASSERT
13
+
14
+ #define GROUP_MAX_EPS 1e-15f
15
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
16
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
17
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
18
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
19
+
20
+ #define UNUSED GGML_UNUSED
21
+
22
+ void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
23
+ quantize_row_q4_0_ref(x, y, k);
24
+ }
25
+
26
+ void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
27
+ quantize_row_q4_1_ref(x, y, k);
28
+ }
29
+
30
+ void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
31
+ quantize_row_q5_0_ref(x, y, k);
32
+ }
33
+
34
+ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
35
+ quantize_row_q5_1_ref(x, y, k);
36
+ }
37
+
38
+ void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
39
+ quantize_row_q8_0_ref(x, y, k);
40
+ }
41
+ GGML_CPU_NATIVE_IMPL(quantize_row_q8_0)
42
+
43
+ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
44
+ quantize_row_q8_1_ref(x, y, k);
45
+ }
46
+ GGML_CPU_NATIVE_IMPL(quantize_row_q8_1)
47
+
48
+ //
49
+ // 2-6 bit quantization in super-blocks
50
+ //
51
+
52
+ //========================- 2-bit (de)-quantization
53
+
54
+ void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
55
+ quantize_row_q2_K_ref(x, vy, k);
56
+ }
57
+
58
+ //========================= 3-bit (de)-quantization
59
+
60
+ void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
61
+ quantize_row_q3_K_ref(x, vy, k);
62
+ }
63
+
64
+ // ====================== 4-bit (de)-quantization
65
+
66
+ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
67
+ assert(k % QK_K == 0);
68
+ block_q4_K * GGML_RESTRICT y = vy;
69
+ quantize_row_q4_K_ref(x, y, k);
70
+ }
71
+
72
+ // ====================== 5-bit (de)-quantization
73
+
74
+ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
75
+ assert(k % QK_K == 0);
76
+ block_q5_K * GGML_RESTRICT y = vy;
77
+ quantize_row_q5_K_ref(x, y, k);
78
+ }
79
+
80
+ // ====================== 6-bit (de)-quantization
81
+
82
+ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
83
+ assert(k % QK_K == 0);
84
+ block_q6_K * GGML_RESTRICT y = vy;
85
+ quantize_row_q6_K_ref(x, y, k);
86
+ }
87
+
88
+ // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
89
+
90
+ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
91
+ assert(k % QK_K == 0);
92
+ block_tq1_0 * GGML_RESTRICT y = vy;
93
+ quantize_row_tq1_0_ref(x, y, k);
94
+ }
95
+
96
+ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
97
+ assert(k % QK_K == 0);
98
+ block_tq2_0 * GGML_RESTRICT y = vy;
99
+ quantize_row_tq2_0_ref(x, y, k);
100
+ }
101
+
102
+ //===================================== Q8_K ==============================================
103
+
104
+ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
105
+ quantize_row_q8_K_ref(x, y, k);
106
+ }
107
+ GGML_CPU_NATIVE_IMPL(quantize_row_q8_K)
108
+
109
+ //===================================== Dot products =================================
110
+
111
+ void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
112
+ const int qk = QK8_0;
113
+ const int nb = n / qk;
114
+
115
+ assert(n % qk == 0);
116
+ assert(nrc == 1);
117
+ UNUSED(nrc);
118
+ UNUSED(bx);
119
+ UNUSED(by);
120
+ UNUSED(bs);
121
+
122
+ const block_q4_0 * GGML_RESTRICT x = vx;
123
+ const block_q8_0 * GGML_RESTRICT y = vy;
124
+
125
+ int ib = 0;
126
+ float sumf = 0;
127
+
128
+ for (; ib < nb; ++ib) {
129
+ int sumi0 = 0;
130
+ int sumi1 = 0;
131
+
132
+ for (int j = 0; j < qk/2; ++j) {
133
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
134
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
135
+
136
+ sumi0 += (v0 * y[ib].qs[j]);
137
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
138
+ }
139
+
140
+ int sumi = sumi0 + sumi1;
141
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
142
+ }
143
+
144
+ *s = sumf;
145
+ }
146
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_0_q8_0)
147
+
148
+ // TODO: add WASM SIMD
149
+ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
150
+ const int qk = QK8_1;
151
+ const int nb = n / qk;
152
+
153
+ assert(n % qk == 0);
154
+ assert(nrc == 1);
155
+ UNUSED(nrc);
156
+ UNUSED(bx);
157
+ UNUSED(by);
158
+ UNUSED(bs);
159
+
160
+ const block_q4_1 * GGML_RESTRICT x = vx;
161
+ const block_q8_1 * GGML_RESTRICT y = vy;
162
+
163
+ int ib = 0;
164
+ float sumf = 0;
165
+
166
+ for (; ib < nb; ++ib) {
167
+ int sumi0 = 0;
168
+ int sumi1 = 0;
169
+
170
+ for (int j = 0; j < qk/2; ++j) {
171
+ const int v0 = (x[ib].qs[j] & 0x0F);
172
+ const int v1 = (x[ib].qs[j] >> 4);
173
+
174
+ sumi0 += (v0 * y[ib].qs[j]);
175
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
176
+ }
177
+
178
+ int sumi = sumi0 + sumi1;
179
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
180
+ }
181
+
182
+ *s = sumf;
183
+ }
184
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_1_q8_1)
185
+
186
+ void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
187
+ const int qk = QK8_0;
188
+ const int nb = n / qk;
189
+
190
+ int ib = 0;
191
+ float sumf = 0;
192
+
193
+ assert(n % qk == 0);
194
+ assert(qk == QK5_0);
195
+ assert(nrc == 1);
196
+ UNUSED(nrc);
197
+ UNUSED(bx);
198
+ UNUSED(by);
199
+ UNUSED(bs);
200
+
201
+ const block_q5_0 * GGML_RESTRICT x = vx;
202
+ const block_q8_0 * GGML_RESTRICT y = vy;
203
+
204
+ for (; ib < nb; ++ib) {
205
+ uint32_t qh;
206
+ memcpy(&qh, x[ib].qh, sizeof(qh));
207
+
208
+ int sumi0 = 0;
209
+ int sumi1 = 0;
210
+
211
+ for (int j = 0; j < qk/2; ++j) {
212
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
213
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
214
+
215
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
216
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
217
+
218
+ sumi0 += (x0 * y[ib].qs[j]);
219
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
220
+ }
221
+
222
+ int sumi = sumi0 + sumi1;
223
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
224
+ }
225
+
226
+ *s = sumf;
227
+ }
228
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_0_q8_0)
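Q5_0 keeps the fifth bit of each of the 32 weights in the separate 32-bit qh field: bit j becomes bit 4 of the j-th low-nibble value and bit j+16 becomes bit 4 of the j-th high-nibble value, after which everything is recentred by 16. Worked for j = 3 with an illustrative nibble:

    // x[ib].qs[3] & 0x0F == 9:
    //   qh bit 3 set    ->  xh_0 = 0x10  ->  x0 = (9 | 16) - 16 =  9
    //   qh bit 3 clear  ->  xh_0 = 0x00  ->  x0 =  9       - 16 = -7
    // so each weight spans [-16, 15] instead of Q4_0's [-8, 7]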
229
+
230
+ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
231
+ const int qk = QK8_1;
232
+ const int nb = n / qk;
233
+
234
+ int ib = 0;
235
+ float sumf = 0;
236
+
237
+ assert(n % qk == 0);
238
+ assert(qk == QK5_1);
239
+ assert(nrc == 1);
240
+ UNUSED(nrc);
241
+ UNUSED(bx);
242
+ UNUSED(by);
243
+ UNUSED(bs);
244
+
245
+ const block_q5_1 * GGML_RESTRICT x = vx;
246
+ const block_q8_1 * GGML_RESTRICT y = vy;
247
+
248
+ for (; ib < nb; ++ib) {
249
+ uint32_t qh;
250
+ memcpy(&qh, x[ib].qh, sizeof(qh));
251
+
252
+ int sumi0 = 0;
253
+ int sumi1 = 0;
254
+
255
+ for (int j = 0; j < qk/2; ++j) {
256
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
257
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
258
+
259
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
260
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
261
+
262
+ sumi0 += (x0 * y[ib].qs[j]);
263
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
264
+ }
265
+
266
+ int sumi = sumi0 + sumi1;
267
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
268
+ }
269
+
270
+ *s = sumf;
271
+ }
272
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_1_q8_1)
273
+
274
+ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
275
+ const int qk = QK8_0;
276
+ const int nb = n / qk;
277
+
278
+ assert(n % qk == 0);
279
+ assert(nrc == 1);
280
+ UNUSED(nrc);
281
+ UNUSED(bx);
282
+ UNUSED(by);
283
+ UNUSED(bs);
284
+
285
+ const block_q8_0 * GGML_RESTRICT x = vx;
286
+ const block_q8_0 * GGML_RESTRICT y = vy;
287
+
288
+ int ib = 0;
289
+ float sumf = 0;
290
+
291
+ for (; ib < nb; ++ib) {
292
+ int sumi = 0;
293
+
294
+ for (int j = 0; j < qk; j++) {
295
+ sumi += x[ib].qs[j]*y[ib].qs[j];
296
+ }
297
+
298
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
299
+ }
300
+
301
+ *s = sumf;
302
+ }
303
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q8_0_q8_0)
304
+
305
+ void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
306
+ assert(nrc == 1);
307
+ UNUSED(nrc);
308
+ UNUSED(bx);
309
+ UNUSED(by);
310
+ UNUSED(bs);
311
+
312
+ const block_tq1_0 * GGML_RESTRICT x = vx;
313
+ const block_q8_K * GGML_RESTRICT y = vy;
314
+
315
+ const int nb = n / QK_K;
316
+
317
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
318
+
319
+ float sumf = 0.0f;
320
+
321
+ for (int i = 0; i < nb; ++i) {
322
+ int sum = 0;
323
+
324
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
325
+ for (size_t l = 0; l < 5; ++l) {
326
+ for (size_t m = 0; m < 32; ++m) {
327
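+ // each qs byte packs 5 ternary digits; multiplying by pow3[l] brings digit l to the top of the byte (mod 256),
+ // (q*3) >> 8 reads it out as 0/1/2, and (xi - 1) maps it to {-1, 0, +1}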
+ uint8_t q = x[i].qs[j + m] * pow3[l];
328
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
329
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
330
+ }
331
+ }
332
+ }
333
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
334
+ for (size_t l = 0; l < 5; ++l) {
335
+ for (size_t m = 0; m < 16; ++m) {
336
+ uint8_t q = x[i].qs[j + m] * pow3[l];
337
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
338
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
339
+ }
340
+ }
341
+ }
342
+
343
+ for (size_t l = 0; l < 4; ++l) {
344
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
345
+ uint8_t q = x[i].qh[j] * pow3[l];
346
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
347
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
348
+ }
349
+ }
350
+
351
+ sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
352
+ }
353
+
354
+ *s = sumf;
355
+ }
356
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_tq1_0_q8_K)
357
+
358
+ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
359
+ assert(nrc == 1);
360
+ UNUSED(nrc);
361
+ UNUSED(bx);
362
+ UNUSED(by);
363
+ UNUSED(bs);
364
+
365
+ const block_tq2_0 * GGML_RESTRICT x = vx;
366
+ const block_q8_K * GGML_RESTRICT y = vy;
367
+
368
+ const int nb = n / QK_K;
369
+ float sumf = 0.0f;
370
+
371
+ for (int i = 0; i < nb; ++i) {
372
+ int32_t sumi = 0;
373
+
374
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
375
+ for (size_t l = 0; l < 4; ++l) {
376
+ for (size_t k = 0; k < 32; ++k) {
377
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
378
+ }
379
+ }
380
+ }
381
+
382
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
383
+
384
+ sumf += (float) sumi * d;
385
+ }
386
+
387
+ *s = sumf;
388
+ }
389
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_tq2_0_q8_K)
390
+
391
+ void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
392
+ assert(nrc == 1);
393
+ UNUSED(nrc);
394
+ UNUSED(bx);
395
+ UNUSED(by);
396
+ UNUSED(bs);
397
+
398
+ const block_q2_K * GGML_RESTRICT x = vx;
399
+ const block_q8_K * GGML_RESTRICT y = vy;
400
+
401
+ const int nb = n / QK_K;
402
+
403
+ float sumf = 0;
404
+
405
+ for (int i = 0; i < nb; ++i) {
406
+
407
+ const uint8_t * q2 = x[i].qs;
408
+ const int8_t * q8 = y[i].qs;
409
+ const uint8_t * sc = x[i].scales;
410
+
411
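+ // each scales byte holds a 4-bit scale (low nibble) and a 4-bit min (high nibble);
+ // the mins are accumulated against the precomputed q8 block sums and subtracted at the end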
+ int summs = 0;
412
+ for (int j = 0; j < 16; ++j) {
413
+ summs += y[i].bsums[j] * (sc[j] >> 4);
414
+ }
415
+
416
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
417
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
418
+
419
+ int isum = 0;
420
+ int is = 0;
421
+ int d;
422
+ for (int k = 0; k < QK_K/128; ++k) {
423
+ int shift = 0;
424
+ for (int j = 0; j < 4; ++j) {
425
+ d = sc[is++] & 0xF;
426
+ int isuml = 0;
427
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
428
+ isum += d * isuml;
429
+ d = sc[is++] & 0xF;
430
+ isuml = 0;
431
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
432
+ isum += d * isuml;
433
+ shift += 2;
434
+ q8 += 32;
435
+ }
436
+ q2 += 32;
437
+ }
438
+ sumf += dall * isum - dmin * summs;
439
+ }
440
+ *s = sumf;
441
+ }
442
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q2_K_q8_K)
443
+
444
+ void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
445
+ assert(n % QK_K == 0);
446
+ assert(nrc == 1);
447
+ UNUSED(nrc);
448
+ UNUSED(bx);
449
+ UNUSED(by);
450
+ UNUSED(bs);
451
+
452
+ const uint32_t kmask1 = 0x03030303;
453
+ const uint32_t kmask2 = 0x0f0f0f0f;
454
+
455
+ const block_q3_K * GGML_RESTRICT x = vx;
456
+ const block_q8_K * GGML_RESTRICT y = vy;
457
+
458
+ const int nb = n / QK_K;
459
+
460
+ // scalar version
461
+ // This function is written like this so the compiler can manage to vectorize most of it
462
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
463
+ // manually vectorized versions. Every other version I tried would run at least 4 times slower.
464
+ // The ideal situation would be if we could just write the code once, and the compiler would
465
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
466
+ // write vectorized versions for AVX, ARM_NEON, etc.
467
+
468
+ int8_t aux8[QK_K];
469
+ int16_t aux16[8];
470
+ float sums [8];
471
+ int32_t aux32[8];
472
+ memset(sums, 0, 8*sizeof(float));
473
+
474
+ uint32_t auxs[4];
475
+ const int8_t * scales = (const int8_t*)auxs;
476
+
477
+ float sumf = 0;
478
+ for (int i = 0; i < nb; ++i) {
479
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
480
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
481
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
482
+ memset(aux32, 0, 8*sizeof(int32_t));
483
+ int8_t * GGML_RESTRICT a = aux8;
484
+ uint8_t m = 1;
485
+ for (int j = 0; j < QK_K; j += 128) {
486
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
487
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
488
+ a += 32; m <<= 1;
489
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
490
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
491
+ a += 32; m <<= 1;
492
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
493
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
494
+ a += 32; m <<= 1;
495
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
496
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
497
+ a += 32; m <<= 1;
498
+ q3 += 32;
499
+ }
500
+ a = aux8;
501
+
502
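+ // unpack the 12-byte packed scales into 16 signed 6-bit scales (low 4 bits + high 2 bits), used below with a bias of 32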
+ memcpy(auxs, x[i].scales, 12);
503
+ uint32_t tmp = auxs[2];
504
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
505
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
506
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
507
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
508
+ for (int j = 0; j < QK_K/16; ++j) {
509
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
510
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
511
+ q8 += 8; a += 8;
512
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
513
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
514
+ q8 += 8; a += 8;
515
+ }
516
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
517
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
518
+ }
519
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
520
+ *s = sumf;
521
+ }
522
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q3_K_q8_K)
523
+
524
+ void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
525
+ assert(n % QK_K == 0);
526
+ assert(nrc == 1);
527
+ UNUSED(nrc);
528
+ UNUSED(bx);
529
+ UNUSED(by);
530
+ UNUSED(bs);
531
+
532
+ const block_q4_K * GGML_RESTRICT x = vx;
533
+ const block_q8_K * GGML_RESTRICT y = vy;
534
+
535
+ const int nb = n / QK_K;
536
+
537
+ static const uint32_t kmask1 = 0x3f3f3f3f;
538
+ static const uint32_t kmask2 = 0x0f0f0f0f;
539
+ static const uint32_t kmask3 = 0x03030303;
540
+
541
+ uint32_t utmp[4];
542
+
543
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
544
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
545
+
546
+ int8_t aux8[QK_K];
547
+ int16_t aux16[8];
548
+ float sums [8];
549
+ int32_t aux32[8];
550
+ memset(sums, 0, 8*sizeof(float));
551
+
552
+ float sumf = 0;
553
+ for (int i = 0; i < nb; ++i) {
554
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
555
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
556
+ memset(aux32, 0, 8*sizeof(int32_t));
557
+ int8_t * GGML_RESTRICT a = aux8;
558
+ for (int j = 0; j < QK_K/64; ++j) {
559
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
560
+ a += 32;
561
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
562
+ a += 32; q4 += 32;
563
+ }
564
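+ // expand the 12-byte packed scales into 8 6-bit scales followed by 8 6-bit mins (utmp[0..1] = scales, utmp[2..3] = mins)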
+ memcpy(utmp, x[i].scales, 12);
565
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
566
+ const uint32_t uaux = utmp[1] & kmask1;
567
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
568
+ utmp[2] = uaux;
569
+ utmp[0] &= kmask1;
570
+
571
+ int sumi = 0;
572
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
573
+ a = aux8;
574
+ int is = 0;
575
+ for (int j = 0; j < QK_K/32; ++j) {
576
+ int32_t scale = scales[is++];
577
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
578
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
579
+ q8 += 8; a += 8;
580
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
581
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
582
+ q8 += 8; a += 8;
583
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
584
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
585
+ q8 += 8; a += 8;
586
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
587
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
588
+ q8 += 8; a += 8;
589
+ }
590
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
591
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
592
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
593
+ sumf -= dmin * sumi;
594
+ }
595
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
596
+ *s = sumf;
597
+ }
598
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_K_q8_K)
599
+
600
+ void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
601
+ assert(n % QK_K == 0);
602
+ assert(nrc == 1);
603
+ UNUSED(nrc);
604
+ UNUSED(bx);
605
+ UNUSED(by);
606
+ UNUSED(bs);
607
+
608
+ const block_q5_K * GGML_RESTRICT x = vx;
609
+ const block_q8_K * GGML_RESTRICT y = vy;
610
+
611
+ const int nb = n / QK_K;
612
+
613
+ static const uint32_t kmask1 = 0x3f3f3f3f;
614
+ static const uint32_t kmask2 = 0x0f0f0f0f;
615
+ static const uint32_t kmask3 = 0x03030303;
616
+
617
+ uint32_t utmp[4];
618
+
619
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
620
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
621
+
622
+ int8_t aux8[QK_K];
623
+ int16_t aux16[8];
624
+ float sums [8];
625
+ int32_t aux32[8];
626
+ memset(sums, 0, 8*sizeof(float));
627
+
628
+ float sumf = 0;
629
+ for (int i = 0; i < nb; ++i) {
630
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
631
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
632
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
633
+ memset(aux32, 0, 8*sizeof(int32_t));
634
+ int8_t * GGML_RESTRICT a = aux8;
635
+ uint8_t m = 1;
636
+ for (int j = 0; j < QK_K/64; ++j) {
637
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
638
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
639
+ a += 32; m <<= 1;
640
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
641
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
642
+ a += 32; m <<= 1;
643
+ q4 += 32;
644
+ }
645
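+ // scales and mins are packed the same way as in q4_K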
+ memcpy(utmp, x[i].scales, 12);
646
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
647
+ const uint32_t uaux = utmp[1] & kmask1;
648
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
649
+ utmp[2] = uaux;
650
+ utmp[0] &= kmask1;
651
+
652
+ int sumi = 0;
653
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
654
+ a = aux8;
655
+ int is = 0;
656
+ for (int j = 0; j < QK_K/32; ++j) {
657
+ int32_t scale = scales[is++];
658
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
659
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
660
+ q8 += 8; a += 8;
661
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
662
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
663
+ q8 += 8; a += 8;
664
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
665
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
666
+ q8 += 8; a += 8;
667
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
668
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
669
+ q8 += 8; a += 8;
670
+ }
671
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
672
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
673
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
674
+ sumf -= dmin * sumi;
675
+ }
676
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
677
+ *s = sumf;
678
+ }
679
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_K_q8_K)
680
+
681
+ void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
682
+ assert(n % QK_K == 0);
683
+ assert(nrc == 1);
684
+ UNUSED(nrc);
685
+ UNUSED(bx);
686
+ UNUSED(by);
687
+ UNUSED(bs);
688
+
689
+ const block_q6_K * GGML_RESTRICT x = vx;
690
+ const block_q8_K * GGML_RESTRICT y = vy;
691
+
692
+ const int nb = n / QK_K;
693
+
694
+ int8_t aux8[QK_K];
695
+ int16_t aux16[8];
696
+ float sums [8];
697
+ int32_t aux32[8];
698
+ memset(sums, 0, 8*sizeof(float));
699
+
700
+ float sumf = 0;
701
+ for (int i = 0; i < nb; ++i) {
702
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
703
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
704
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
705
+ memset(aux32, 0, 8*sizeof(int32_t));
706
+ int8_t * GGML_RESTRICT a = aux8;
707
+ for (int j = 0; j < QK_K; j += 128) {
708
+ for (int l = 0; l < 32; ++l) {
709
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
710
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
711
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
712
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
713
+ }
714
+ a += 128;
715
+ q4 += 64;
716
+ qh += 32;
717
+ }
718
+ a = aux8;
719
+ int is = 0;
720
+ for (int j = 0; j < QK_K/16; ++j) {
721
+ int scale = x[i].scales[is++];
722
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
723
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
724
+ q8 += 8; a += 8;
725
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
726
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
727
+ q8 += 8; a += 8;
728
+ }
729
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
730
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
731
+ }
732
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
733
+ *s = sumf;
734
+ }
735
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q6_K_q8_K)
736
+
737
+ void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
738
+ assert(n % QK_K == 0);
739
+ assert(nrc == 1);
740
+ UNUSED(nrc);
741
+ UNUSED(bx);
742
+ UNUSED(by);
743
+ UNUSED(bs);
744
+
745
+ const block_iq2_xxs * GGML_RESTRICT x = vx;
746
+ const block_q8_K * GGML_RESTRICT y = vy;
747
+
748
+ const int nb = n / QK_K;
749
+
750
+ uint32_t aux32[2];
751
+ const uint8_t * aux8 = (const uint8_t *)aux32;
752
+
753
+ float sumf = 0.f;
754
+ for (int i = 0; i < nb; ++i) {
755
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
756
+ const uint16_t * GGML_RESTRICT q2 = x[i].qs;
757
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
758
+ int32_t bsum = 0;
759
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
760
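+ // per 32-value sub-block: aux32[0] holds four 8-bit grid indices, aux32[1] holds four 7-bit sign selectors
+ // plus a 4-bit scale in its top bits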
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
761
+ q2 += 4;
762
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
763
+ int32_t sumi = 0;
764
+ for (int l = 0; l < 4; ++l) {
765
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
766
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
767
+ for (int j = 0; j < 8; ++j) {
768
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
769
+ }
770
+ q8 += 8;
771
+ }
772
+ bsum += sumi * ls;
773
+ }
774
+ sumf += d * bsum;
775
+ }
776
+ *s = 0.125f * sumf;
777
+ }
778
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_xxs_q8_K)
779
+
780
+ void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
781
+ assert(n % QK_K == 0);
782
+ assert(nrc == 1);
783
+ UNUSED(nrc);
784
+ UNUSED(bx);
785
+ UNUSED(by);
786
+ UNUSED(bs);
787
+
788
+ const block_iq2_xs * GGML_RESTRICT x = vx;
789
+ const block_q8_K * GGML_RESTRICT y = vy;
790
+
791
+ const int nb = n / QK_K;
792
+
793
+ float sumf = 0.f;
794
+ for (int i = 0; i < nb; ++i) {
795
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
796
+ const uint16_t * GGML_RESTRICT q2 = x[i].qs;
797
+ const uint8_t * GGML_RESTRICT sc = x[i].scales;
798
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
799
+ int32_t bsum = 0;
800
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
801
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
802
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
803
+ int32_t sumi = 0;
804
+ for (int l = 0; l < 2; ++l) {
805
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
806
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
807
+ for (int j = 0; j < 8; ++j) {
808
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
809
+ }
810
+ q8 += 8;
811
+ }
812
+ bsum += sumi * ls1;
813
+ sumi = 0;
814
+ for (int l = 2; l < 4; ++l) {
815
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
816
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
817
+ for (int j = 0; j < 8; ++j) {
818
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
819
+ }
820
+ q8 += 8;
821
+ }
822
+ bsum += sumi * ls2;
823
+ q2 += 4;
824
+ }
825
+ sumf += d * bsum;
826
+ }
827
+ *s = 0.125f * sumf;
828
+ }
829
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_xs_q8_K)
830
+
831
+ void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
832
+ assert(n % QK_K == 0);
833
+ assert(nrc == 1);
834
+ UNUSED(nrc);
835
+ UNUSED(bx);
836
+ UNUSED(by);
837
+ UNUSED(bs);
838
+
839
+ const block_iq2_s * GGML_RESTRICT x = vx;
840
+ const block_q8_K * GGML_RESTRICT y = vy;
841
+
842
+ const int nb = n / QK_K;
843
+
844
+ float sumf = 0;
845
+ for (int i = 0; i < nb; i++) {
846
+
847
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
848
+ const int8_t * q8 = y[i].qs;
849
+ const uint8_t * qs = x[i].qs;
850
+ const uint8_t * qh = x[i].qh;
851
+ const uint8_t * signs = qs + QK_K/8;
852
+
853
+ int bsum = 0;
854
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
855
+ int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
856
+ int ls2 = 1 + 2*(x[i].scales[ib32] >> 4);
857
+ int sumi1 = 0, sumi2 = 0;
858
+ for (int l = 0; l < 2; ++l) {
859
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
860
+ for (int j = 0; j < 8; ++j) {
861
+ sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
862
+ }
863
+ q8 += 8;
864
+ }
865
+ for (int l = 2; l < 4; ++l) {
866
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
867
+ for (int j = 0; j < 8; ++j) {
868
+ sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
869
+ }
870
+ q8 += 8;
871
+ }
872
+ bsum += ls1 * sumi1 + ls2 * sumi2;
873
+ qs += 4;
874
+ signs += 4;
875
+ }
876
+
877
+ sumf += d * bsum;
878
+ }
879
+
880
+ *s = 0.125f * sumf;
881
+ }
882
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_s_q8_K)
883
+
884
+ void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
885
+ assert(n % QK_K == 0);
886
+ assert(nrc == 1);
887
+ UNUSED(nrc);
888
+ UNUSED(bx);
889
+ UNUSED(by);
890
+ UNUSED(bs);
891
+
892
+ const block_iq3_xxs * GGML_RESTRICT x = vx;
893
+ const block_q8_K * GGML_RESTRICT y = vy;
894
+
895
+ const int nb = n / QK_K;
896
+
897
+ uint32_t aux32;
898
+
899
+ float sumf = 0.f;
900
+ for (int i = 0; i < nb; ++i) {
901
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
902
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
903
+ const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
904
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
905
+ int32_t bsum = 0;
906
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
907
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
908
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
909
+ int32_t sumi = 0;
910
+ for (int l = 0; l < 4; ++l) {
911
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
912
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
913
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
914
+ for (int j = 0; j < 4; ++j) {
915
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
916
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
917
+ }
918
+ q8 += 8;
919
+ }
920
+ q3 += 8;
921
+ bsum += sumi * ls;
922
+ }
923
+ sumf += d * bsum;
924
+ }
925
+ *s = 0.25f * sumf;
926
+ }
927
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq3_xxs_q8_K)
928
+
929
+ void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
930
+ assert(n % QK_K == 0);
931
+ assert(nrc == 1);
932
+ UNUSED(nrc);
933
+ UNUSED(bx);
934
+ UNUSED(by);
935
+ UNUSED(bs);
936
+
937
+ const block_iq3_s * GGML_RESTRICT x = vx;
938
+ const block_q8_K * GGML_RESTRICT y = vy;
939
+
940
+ const int nb = n / QK_K;
941
+
942
+ float sumf = 0.f;
943
+ for (int i = 0; i < nb; ++i) {
944
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
945
+ const uint8_t * GGML_RESTRICT qs = x[i].qs;
946
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
947
+ const uint8_t * GGML_RESTRICT signs = x[i].signs;
948
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
949
+ int32_t bsum = 0;
950
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
951
+ const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
952
+ const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
953
+ int32_t sumi = 0;
954
+ for (int l = 0; l < 4; ++l) {
955
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
956
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
957
+ for (int j = 0; j < 4; ++j) {
958
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
959
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
960
+ }
961
+ q8 += 8;
962
+ }
963
+ qs += 8;
964
+ signs += 4;
965
+ bsum += sumi * ls1;
966
+ sumi = 0;
967
+ for (int l = 0; l < 4; ++l) {
968
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
969
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
970
+ for (int j = 0; j < 4; ++j) {
971
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
972
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
973
+ }
974
+ q8 += 8;
975
+ }
976
+ qs += 8;
977
+ signs += 4;
978
+ bsum += sumi * ls2;
979
+ }
980
+ sumf += d * bsum;
981
+ }
982
+ *s = sumf;
983
+ }
984
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq3_s_q8_K)
985
+
986
+ void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
987
+ assert(n % QK_K == 0);
988
+ assert(nrc == 1);
989
+ UNUSED(nrc);
990
+ UNUSED(bx);
991
+ UNUSED(by);
992
+ UNUSED(bs);
993
+
994
+ const block_iq1_s * GGML_RESTRICT x = vx;
995
+ const block_q8_K * GGML_RESTRICT y = vy;
996
+
997
+ const int nb = n / QK_K;
998
+
999
+ float sumf = 0;
1000
+ for (int i = 0; i < nb; i++) {
1001
+
1002
+ const int8_t * q8 = y[i].qs;
1003
+ const uint8_t * qs = x[i].qs;
1004
+ const uint16_t * qh = x[i].qh;
1005
+
1006
+ int sumi = 0, sumi1 = 0;
1007
+ for (int ib = 0; ib < QK_K/32; ++ib) {
1008
+ const int ls = 2*((qh[ib] >> 12) & 7) + 1;
1009
+ const int delta = qh[ib] & 0x8000 ? -1 : 1;
1010
+ int lsum = 0;
1011
+ for (int l = 0; l < 4; ++l) {
1012
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
1013
+ for (int j = 0; j < 8; ++j) {
1014
+ lsum += q8[j] * grid[j];
1015
+ }
1016
+ q8 += 8;
1017
+ }
1018
+ sumi += ls * lsum;
1019
+ sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
1020
+ qs += 4;
1021
+ }
1022
+
1023
+ sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
1024
+ }
1025
+
1026
+ *s = sumf;
1027
+ }
1028
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq1_s_q8_K)
1029
+
1030
+ void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1031
+ assert(n % QK_K == 0);
1032
+ assert(nrc == 1);
1033
+ UNUSED(nrc);
1034
+ UNUSED(bx);
1035
+ UNUSED(by);
1036
+ UNUSED(bs);
1037
+
1038
+ const block_iq1_m * GGML_RESTRICT x = vx;
1039
+ const block_q8_K * GGML_RESTRICT y = vy;
1040
+
1041
+ const int nb = n / QK_K;
1042
+
1043
+ iq1m_scale_t scale;
1044
+
1045
+ int sum1[2], sum2[2], delta[4];
1046
+
1047
+ float sumf = 0;
1048
+ for (int i = 0; i < nb; i++) {
1049
+
1050
+ const int8_t * q8 = y[i].qs;
1051
+ const uint8_t * qs = x[i].qs;
1052
+ const uint8_t * qh = x[i].qh;
1053
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
1054
+
1055
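+ // the per-block fp16 scale is scattered across the top 4 bits of the four 16-bit scale words; reassemble it here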
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
1056
+
1057
+ int sumi1 = 0, sumi2 = 0;
1058
+ for (int ib = 0; ib < QK_K/32; ++ib) {
1059
+ delta[0] = qh[0] & 0x08 ? -1 : 1;
1060
+ delta[1] = qh[0] & 0x80 ? -1 : 1;
1061
+ delta[2] = qh[1] & 0x08 ? -1 : 1;
1062
+ delta[3] = qh[1] & 0x80 ? -1 : 1;
1063
+ sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
1064
+ for (int l = 0; l < 4; ++l) {
1065
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
1066
+ int lsum1 = 0, lsum2 = 0;
1067
+ for (int j = 0; j < 8; ++j) {
1068
+ lsum1 += q8[j] * grid[j];
1069
+ lsum2 += q8[j];
1070
+ }
1071
+ q8 += 8;
1072
+ sum1[l/2] += lsum1;
1073
+ sum2[l/2] += lsum2*delta[l];
1074
+ }
1075
+
1076
+ const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
1077
+ const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
1078
+
1079
+ sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
1080
+ sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
1081
+ qs += 4;
1082
+ qh += 2;
1083
+ }
1084
+
1085
+ sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
1086
+ }
1087
+
1088
+ *s = sumf;
1089
+ }
1090
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq1_m_q8_K)
1091
+
1092
+ void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1093
+ assert(nrc == 1);
1094
+ UNUSED(nrc);
1095
+ UNUSED(bx);
1096
+ UNUSED(by);
1097
+ UNUSED(bs);
1098
+ assert(n % QK4_NL == 0);
1099
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
1100
+
1101
+ const block_iq4_nl * GGML_RESTRICT x = vx;
1102
+ const block_q8_0 * GGML_RESTRICT y = vy;
1103
+
1104
+ const int nb = n / QK4_NL;
1105
+
1106
+ int ib = 0;
1107
+ float sumf = 0;
1108
+
1109
+ for (; ib < nb; ++ib) {
1110
+ const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
1111
+ int sumi1 = 0, sumi2 = 0;
1112
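+ // the 4-bit indices select entries from the non-linear kvalues_iq4nl codebook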
+ for (int j = 0; j < QK4_NL/2; ++j) {
1113
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
1114
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
1115
+ }
1116
+ sumf += d * (sumi1 + sumi2);
1117
+ }
1118
+ *s = sumf;
1119
+ }
1120
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq4_nl_q8_0)
1121
+
1122
+ void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1123
+ assert(nrc == 1);
1124
+ UNUSED(nrc);
1125
+ UNUSED(bx);
1126
+ UNUSED(by);
1127
+ UNUSED(bs);
1128
+ assert(n % QK_K == 0);
1129
+
1130
+ const block_iq4_xs * GGML_RESTRICT x = vx;
1131
+ const block_q8_K * GGML_RESTRICT y = vy;
1132
+
1133
+ const int nb = n / QK_K;
1134
+
1135
+ float sumf = 0;
1136
+ for (int ibl = 0; ibl < nb; ++ibl) {
1137
+ const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
1138
+ uint16_t h = x[ibl].scales_h;
1139
+ const uint8_t * qs = x[ibl].qs;
1140
+ const int8_t * q8 = y[ibl].qs;
1141
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
1142
+ const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
1143
+ const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
1144
+ h >>= 4;
1145
+ const float d1 = d4d8*(ls1 - 32);
1146
+ const float d2 = d4d8*(ls2 - 32);
1147
+ int sumi1 = 0, sumi2 = 0;
1148
+ for (int j = 0; j < 16; ++j) {
1149
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
1150
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
1151
+ }
1152
+ sumf += d1 * (sumi1 + sumi2);
1153
+ qs += 16;
1154
+ q8 += 32;
1155
+ sumi1 = sumi2 = 0;
1156
+ for (int j = 0; j < 16; ++j) {
1157
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
1158
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
1159
+ }
1160
+ sumf += d2 * (sumi1 + sumi2);
1161
+ qs += 16;
1162
+ q8 += 32;
1163
+ }
1164
+ }
1165
+ *s = sumf;
1166
+ }
1167
+ GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq4_xs_q8_K)
1168
+
1169
+ // ============================ 4-bit non-linear quants
1170
+
1171
+ void quantize_row_iq4_nl(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
1172
+ assert(k % QK4_NL == 0);
1173
+ quantize_row_iq4_nl_ref(x, y, k);
1174
+ }
1175
+
1176
+ void quantize_row_iq4_xs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
1177
+ assert(k % QK_K == 0);
1178
+ quantize_iq4_xs(x, y, 1, k, NULL);
1179
+ }
ggml/src/ggml-cpu/quants.h ADDED
@@ -0,0 +1,116 @@
1
+ #pragma once
2
+
3
+ #define GGML_COMMON_DECL_C
4
+ #include "ggml-common.h"
5
+
6
+ #include "ggml.h"
7
+
8
+ // GGML CPU internal header
9
+
10
+ #ifdef __cplusplus
11
+ extern "C" {
12
+ #endif
13
+
14
+ // Quantization
15
+ void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
16
+ void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
17
+ void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
18
+ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
19
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
20
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
21
+
22
+ void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
23
+ void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
24
+ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
25
+ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
26
+ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
27
+ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
28
+
29
+ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
30
+ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
31
+
32
+ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
33
+ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
34
+
35
+ // Dot product
36
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
37
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
38
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
39
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
40
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
41
+
42
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
43
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
44
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
45
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
46
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
47
+
48
+ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
49
+ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
50
+
51
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
52
+ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
53
+ void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
54
+ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
55
+ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
56
+ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
57
+ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
58
+ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
59
+ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
60
+
61
+ // Generic implementation
62
+ void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
63
+ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
64
+ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
65
+ void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
66
+ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
67
+ void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
68
+ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
69
+ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
70
+ void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
71
+ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
72
+ void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
73
+ void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
74
+ void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
75
+ void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
76
+ void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
77
+ void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
78
+ void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
79
+ void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
80
+ void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
81
+ void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
82
+ void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
83
+ void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
84
+ void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
85
+ void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
86
+
87
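+ // on a generic build (no arch-specific implementations), the *_generic definitions are compiled
+ // directly under the public names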
+ #if defined(GGML_CPU_GENERIC)
88
+ #define quantize_row_q8_0_generic quantize_row_q8_0
89
+ #define quantize_row_q8_1_generic quantize_row_q8_1
90
+ #define quantize_row_q8_K_generic quantize_row_q8_K
91
+ #define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0
92
+ #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
93
+ #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
94
+ #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
95
+ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
96
+ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
97
+ #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
98
+ #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
99
+ #define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K
100
+ #define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K
101
+ #define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K
102
+ #define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K
103
+ #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
104
+ #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
105
+ #define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
106
+ #define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
107
+ #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
108
+ #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
109
+ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
110
+ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
111
+ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
112
+ #endif
113
+
114
+ #ifdef __cplusplus
115
+ }
116
+ #endif
ggml/src/ggml-cpu/repack.cpp ADDED
@@ -0,0 +1,1566 @@
1
+ #define GGML_COMMON_IMPL_CPP
2
+ #define GGML_COMMON_DECL_CPP
3
+ #include "ggml-common.h"
4
+ #include "ggml-backend-impl.h"
5
+
6
+ #include "ggml-impl.h"
7
+ #include "ggml-cpu.h"
8
+ #include "ggml-cpu-impl.h"
9
+ #include "traits.h"
10
+
11
+ #include <cmath>
12
+ #include <cstring>
13
+ #include <cassert>
14
+ #include <cstdlib> // for qsort
15
+ #include <cstdio> // for GGML_ASSERT
16
+
17
+ #include "repack.h"
18
+
19
+ #if defined(__GNUC__)
20
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
21
+ #endif
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
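+ // round-to-nearest via the float mantissa trick: adding 1.5*2^23 leaves the rounded value in the low 23 mantissa bits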
+ static inline int nearest_int(float fval) {
26
+ assert(fabsf(fval) <= 4194303.f);
27
+ float val = fval + 12582912.f;
28
+ int i; memcpy(&i, &val, sizeof(int));
29
+ return (i & 0x007fffff) - 0x00400000;
30
+ }
31
+
32
+ // Functions to create the interleaved data layout formats
33
+
34
+ // interleave 4 block_q4_0s in blocks of blck_size_interleave
35
+ // returns an interleaved block_q4_0x4
36
+ // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
37
+ // first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
38
+ //
39
+ // - in : an array of block_q4_0 pointers
40
+ // - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
41
+ // blck_size_interleave bytes
42
+ // - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes
43
+ // from bias offset form to pure sign form (this saves subtract
44
+ //                           operations during unpacking)
45
+ //
46
+
47
+ extern "C" {
48
+
49
+ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
50
+ assert(QK8_0 == 32);
51
+ assert(k % QK8_0 == 0);
52
+ const int nb = k / QK8_0;
53
+
54
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
55
+
56
+ // scalar
57
+ const int blck_size_interleave = 4;
58
+ float srcv[4][QK8_0];
59
+ float id[4];
60
+
61
+ for (int i = 0; i < nb; i++) {
62
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
63
+ float amax = 0.0f; // absolute max
64
+
65
+ for (int j = 0; j < QK8_0; j++) {
66
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
67
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
68
+ }
69
+
70
+ const float d = amax / ((1 << 7) - 1);
71
+ id[row_iter] = d ? 1.0f / d : 0.0f;
72
+
73
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
74
+ }
75
+
76
+ for (int j = 0; j < QK8_0 * 4; j++) {
77
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
78
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
79
+ src_offset += (j % blck_size_interleave);
80
+
81
+ float x0 = srcv[src_id][src_offset] * id[src_id];
82
+ y[i].qs[j] = roundf(x0);
83
+ }
84
+ }
85
+ }
86
+ GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_0_4x4)
87
+
88
+ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
89
+ assert(QK8_0 == 32);
90
+ assert(k % QK8_0 == 0);
91
+ const int nb = k / QK8_0;
92
+
93
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
94
+
95
+ // scalar
96
+ const int blck_size_interleave = 8;
97
+ float srcv[4][QK8_0];
98
+ float id[4];
99
+
100
+ for (int i = 0; i < nb; i++) {
101
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
102
+ float amax = 0.0f; // absolute max
103
+
104
+ for (int j = 0; j < QK8_0; j++) {
105
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
106
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
107
+ }
108
+
109
+ const float d = amax / ((1 << 7) - 1);
110
+ id[row_iter] = d ? 1.0f / d : 0.0f;
111
+
112
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
113
+ }
114
+
115
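+ // same interleaving as the 4x4 variant, with 8-byte runs per row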
+ for (int j = 0; j < QK8_0 * 4; j++) {
116
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
117
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
118
+ src_offset += (j % blck_size_interleave);
119
+
120
+ float x0 = srcv[src_id][src_offset] * id[src_id];
121
+ y[i].qs[j] = roundf(x0);
122
+ }
123
+ }
124
+ }
125
+ GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_0_4x8)
126
+
127
+ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
128
+ assert(QK_K == 256);
129
+ assert(k % QK_K == 0);
130
+ const int nb = k / QK_K;
131
+
132
+ block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
133
+
134
+ // scalar
135
+ const int blck_size_interleave = 8;
136
+ float srcv[4][QK_K];
137
+ float iscale[4];
138
+
139
+ for (int i = 0; i < nb; i++) {
140
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
141
+ float amax = 0.0f; // absolute max
142
+ float max = 0;
143
+
144
+ for (int j = 0; j < QK_K; j++) {
145
+ srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
146
+ // Update the maximum value of the corresponding super block
147
+ if(amax < fabsf(srcv[row_iter][j])) {
148
+ amax = fabsf(srcv[row_iter][j]);
149
+ max = srcv[row_iter][j];
150
+ }
151
+ }
152
+
153
+ iscale[row_iter] = amax ? -127.f/max : 0;
154
+
155
+ y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
156
+ }
157
+
158
+ for (int j = 0; j < QK_K / 4; j++) {
159
+ y[i].bsums[j] = 0;
160
+ }
161
+
162
+ // Quants values are interleaved in sequence of eight bytes from corresponding super blocks
163
+ // Bsums values are interleaved in runs of four bsums taken from each super block in turn,
164
+ // i.e. the first four bsums from the first super block, then the first four from the second super block, and so on
165
+ for (int j = 0; j < QK_K * 4; j++) {
166
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
167
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
168
+ src_offset += (j % blck_size_interleave);
169
+ int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
170
+
171
+ float x0 = srcv[src_id][src_offset] * iscale[src_id];
172
+ y[i].qs[j] = nearest_int(x0);
173
+ y[i].bsums[index] += y[i].qs[j];
174
+ }
175
+ }
176
+ }
177
+ GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_K_4x8)
178
+
179
+ } // extern "C"
180
+
181
+ template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
182
+ void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
183
+
184
+ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
185
+ assert(nrow == 4);
186
+ UNUSED(nrow);
187
+ ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
188
+ }
189
+
190
+ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
191
+ assert(nrow == 4);
192
+ UNUSED(nrow);
193
+ ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
194
+ }
195
+
196
+ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
197
+ assert(nrow == 4);
198
+ UNUSED(nrow);
199
+ ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
200
+ }
201
+
202
+ extern "C" {
203
+
204
+ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
205
+ const int qk = QK8_0;
206
+ const int nb = n / qk;
207
+ const int ncols_interleaved = 4;
208
+ const int blocklen = 4;
209
+
210
+ assert (n % qk == 0);
211
+ assert (nc % ncols_interleaved == 0);
212
+
213
+ UNUSED(s);
214
+ UNUSED(bs);
215
+ UNUSED(vx);
216
+ UNUSED(vy);
217
+ UNUSED(nr);
218
+ UNUSED(nc);
219
+ UNUSED(nb);
220
+ UNUSED(ncols_interleaved);
221
+ UNUSED(blocklen);
222
+
223
+ float sumf[4];
224
+ int sumi;
225
+
226
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
227
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
228
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
229
+
230
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
231
+ for (int l = 0; l < nb; l++) {
232
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
233
+ for (int j = 0; j < ncols_interleaved; j++) {
234
+ sumi = 0;
235
+ for (int i = 0; i < blocklen; ++i) {
236
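+ // the repacked nibbles are in sign form (see the xor_mask note above): shifting them into the top of an
+ // int8 sign-extends them scaled by 16, and the final >> 4 removes that scaling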
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
237
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
238
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
239
+ }
240
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
241
+ }
242
+ }
243
+ }
244
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
245
+ }
246
+ }
247
+ GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_4x4_q8_0)
248
+
249
+ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
250
+ const int qk = QK8_0;
251
+ const int nb = n / qk;
252
+ const int ncols_interleaved = 4;
253
+ const int blocklen = 8;
254
+
255
+ assert (n % qk == 0);
256
+ assert (nc % ncols_interleaved == 0);
257
+
258
+ UNUSED(s);
259
+ UNUSED(bs);
260
+ UNUSED(vx);
261
+ UNUSED(vy);
262
+ UNUSED(nr);
263
+ UNUSED(nc);
264
+ UNUSED(nb);
265
+ UNUSED(ncols_interleaved);
266
+ UNUSED(blocklen);
267
+
268
+ float sumf[4];
269
+ int sumi;
270
+
271
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
272
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
273
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
274
+
275
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
276
+ for (int l = 0; l < nb; l++) {
277
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
278
+ for (int j = 0; j < ncols_interleaved; j++) {
279
+ sumi = 0;
280
+ for (int i = 0; i < blocklen; ++i) {
281
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
282
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
283
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
284
+ }
285
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
286
+ }
287
+ }
288
+ }
289
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
290
+ }
291
+ }
292
+ GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_4x8_q8_0)
293
+
294
+ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
295
+ const int qk = QK8_0;
296
+ const int nb = n / qk;
297
+ const int ncols_interleaved = 8;
298
+ const int blocklen = 8;
299
+
300
+ assert (n % qk == 0);
301
+ assert (nc % ncols_interleaved == 0);
302
+
303
+ UNUSED(s);
304
+ UNUSED(bs);
305
+ UNUSED(vx);
306
+ UNUSED(vy);
307
+ UNUSED(nr);
308
+ UNUSED(nc);
309
+ UNUSED(nb);
310
+ UNUSED(ncols_interleaved);
311
+ UNUSED(blocklen);
312
+
313
+ {
314
+ float sumf[8];
315
+ int sumi;
316
+
317
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
318
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
319
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
320
+
321
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
322
+ for (int l = 0; l < nb; l++) {
323
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
324
+ for (int j = 0; j < ncols_interleaved; j++) {
325
+ sumi = 0;
326
+ for (int i = 0; i < blocklen; ++i) {
327
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
328
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
329
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
330
+ }
331
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
332
+ }
333
+ }
334
+ }
335
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
336
+ }
337
+ }
338
+ }
339
+ GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_8x8_q8_0)
340
+
341
+ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
342
+ const int qk = QK_K;
343
+ const int nb = n / qk;
344
+ const int ncols_interleaved = 8;
345
+ const int blocklen = 8;
346
+ static const uint32_t kmask1 = 0x3f3f3f3f;
347
+ static const uint32_t kmask2 = 0x0f0f0f0f;
348
+ static const uint32_t kmask3 = 0x03030303;
349
+
350
+ assert (n % qk == 0);
351
+ assert (nc % ncols_interleaved == 0);
352
+
353
+ UNUSED(s);
354
+ UNUSED(bs);
355
+ UNUSED(vx);
356
+ UNUSED(vy);
357
+ UNUSED(nr);
358
+ UNUSED(nc);
359
+ UNUSED(nb);
360
+ UNUSED(ncols_interleaved);
361
+ UNUSED(blocklen);
362
+
363
+ float sumf[8];
364
+ float sum_minf[8];
365
+ uint32_t utmp[32];
366
+ int sumi1;
367
+ int sumi2;
368
+ int sumi;
369
+
370
+ const block_q8_K * a_ptr = (const block_q8_K *) vy;
371
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
372
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
373
+
374
+ for (int j = 0; j < ncols_interleaved; j++) {
375
+ sumf[j] = 0.0;
376
+ sum_minf[j] = 0.0;
377
+ }
378
+ for (int l = 0; l < nb; l++) {
379
+ for (int sb = 0; sb < 8; sb++) {
380
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
381
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
382
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
383
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
384
+ utmp[sb * 4 + 2] = uaux_0;
385
+ utmp[sb * 4 + 0] &= kmask1;
386
+ }
387
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
388
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
389
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
390
+ for (int j = 0; j < ncols_interleaved; j++) {
391
+ sumi1 = 0;
392
+ sumi2 = 0;
393
+ sumi = 0;
394
+ for (int i = 0; i < blocklen; ++i) {
395
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
396
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
397
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
398
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
399
+ sumi1 = sumi1 * scales_0[j];
400
+ sumi2 = sumi2 * scales_1[j];
401
+ sumi += sumi1 + sumi2;
402
+ }
403
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
404
+ }
405
+ }
406
+ for (int sb = 0; sb < 8; sb++) {
407
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
408
+ for (int j = 0; j < ncols_interleaved; j++) {
409
+ sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
410
+ }
411
+ }
412
+ }
413
+ for (int j = 0; j < ncols_interleaved; j++) {
414
+ s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
415
+ }
416
+ }
417
+ }
418
+ GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_K_8x8_q8_K)
419
+
420
+ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
421
+ const int qk = QK8_0;
422
+ const int nb = n / qk;
423
+ const int ncols_interleaved = 4;
424
+ const int blocklen = 4;
425
+
426
+ assert (n % qk == 0);
427
+ assert (nc % ncols_interleaved == 0);
428
+
429
+ UNUSED(s);
430
+ UNUSED(bs);
431
+ UNUSED(vx);
432
+ UNUSED(vy);
433
+ UNUSED(nr);
434
+ UNUSED(nc);
435
+ UNUSED(nb);
436
+ UNUSED(ncols_interleaved);
437
+ UNUSED(blocklen);
438
+
439
+ {
440
+ float sumf[4];
441
+ int sumi;
442
+
443
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
444
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
445
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
446
+
447
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
448
+ for (int l = 0; l < nb; l++) {
449
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
450
+ for (int j = 0; j < ncols_interleaved; j++) {
451
+ sumi = 0;
452
+ for (int i = 0; i < blocklen; ++i) {
453
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
454
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
455
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
456
+ }
457
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
458
+ }
459
+ }
460
+ }
461
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
462
+ }
463
+ }
464
+ }
465
+ GGML_CPU_NATIVE_IMPL(ggml_gemv_iq4_nl_4x4_q8_0)
466
+
467
+ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
468
+ const int qk = QK8_0;
469
+ const int nb = n / qk;
470
+ const int ncols_interleaved = 4;
471
+ const int blocklen = 4;
472
+
473
+ assert (n % qk == 0);
474
+ assert (nr % 4 == 0);
475
+ assert (nc % ncols_interleaved == 0);
476
+
477
+ UNUSED(s);
478
+ UNUSED(bs);
479
+ UNUSED(vx);
480
+ UNUSED(vy);
481
+ UNUSED(nr);
482
+ UNUSED(nc);
483
+ UNUSED(nb);
484
+ UNUSED(ncols_interleaved);
485
+ UNUSED(blocklen);
486
+
487
+ {
488
+ float sumf[4][4];
489
+ int sumi;
490
+
491
+ for (int y = 0; y < nr / 4; y++) {
492
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
493
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
494
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
495
+ for (int m = 0; m < 4; m++) {
496
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
497
+ }
498
+ for (int l = 0; l < nb; l++) {
499
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
500
+ for (int m = 0; m < 4; m++) {
501
+ for (int j = 0; j < ncols_interleaved; j++) {
502
+ sumi = 0;
503
+ for (int i = 0; i < blocklen; ++i) {
504
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
505
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
506
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
507
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
508
+ }
509
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
510
+ }
511
+ }
512
+ }
513
+ }
514
+ for (int m = 0; m < 4; m++) {
515
+ for (int j = 0; j < ncols_interleaved; j++)
516
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
517
+ }
518
+ }
519
+ }
520
+ }
521
+ }
522
+ GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_4x4_q8_0)
523
+
524
+ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
525
+ const int qk = QK8_0;
526
+ const int nb = n / qk;
527
+ const int ncols_interleaved = 4;
528
+ const int blocklen = 8;
529
+
530
+ assert (n % qk == 0);
531
+ assert (nr % 4 == 0);
532
+ assert (nc % ncols_interleaved == 0);
533
+
534
+ UNUSED(s);
535
+ UNUSED(bs);
536
+ UNUSED(vx);
537
+ UNUSED(vy);
538
+ UNUSED(nr);
539
+ UNUSED(nc);
540
+ UNUSED(nb);
541
+ UNUSED(ncols_interleaved);
542
+ UNUSED(blocklen);
543
+
544
+ float sumf[4][4];
545
+ int sumi;
546
+
547
+ for (int y = 0; y < nr / 4; y++) {
548
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
549
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
550
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
551
+ for (int m = 0; m < 4; m++) {
552
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
553
+ }
554
+ for (int l = 0; l < nb; l++) {
555
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
556
+ for (int m = 0; m < 4; m++) {
557
+ for (int j = 0; j < ncols_interleaved; j++) {
558
+ sumi = 0;
559
+ for (int i = 0; i < blocklen; ++i) {
560
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
561
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
562
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
563
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
564
+ }
565
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
566
+ }
567
+ }
568
+ }
569
+ }
570
+ for (int m = 0; m < 4; m++) {
571
+ for (int j = 0; j < ncols_interleaved; j++)
572
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
573
+ }
574
+ }
575
+ }
576
+ }
577
+ GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_4x8_q8_0)
578
+
579
+ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
580
+ const int qk = QK8_0;
581
+ const int nb = n / qk;
582
+ const int ncols_interleaved = 8;
583
+ const int blocklen = 8;
584
+
585
+ assert (n % qk == 0);
586
+ assert (nr % 4 == 0);
587
+ assert (nc % ncols_interleaved == 0);
588
+
589
+ UNUSED(s);
590
+ UNUSED(bs);
591
+ UNUSED(vx);
592
+ UNUSED(vy);
593
+ UNUSED(nr);
594
+ UNUSED(nc);
595
+ UNUSED(nb);
596
+ UNUSED(ncols_interleaved);
597
+ UNUSED(blocklen);
598
+
599
+ float sumf[4][8];
600
+ int sumi;
601
+
602
+ for (int y = 0; y < nr / 4; y++) {
603
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
604
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
605
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
606
+ for (int m = 0; m < 4; m++) {
607
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
608
+ }
609
+ for (int l = 0; l < nb; l++) {
610
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
611
+ for (int m = 0; m < 4; m++) {
612
+ for (int j = 0; j < ncols_interleaved; j++) {
613
+ sumi = 0;
614
+ for (int i = 0; i < blocklen; ++i) {
615
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
616
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
617
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
618
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
619
+ }
620
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
621
+ }
622
+ }
623
+ }
624
+ }
625
+ for (int m = 0; m < 4; m++) {
626
+ for (int j = 0; j < ncols_interleaved; j++)
627
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
628
+ }
629
+ }
630
+ }
631
+ }
632
+ GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_8x8_q8_0)
633
+
634
+ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
635
+ const int qk = QK_K;
636
+ const int nb = n / qk;
637
+ const int ncols_interleaved = 8;
638
+ const int blocklen = 8;
639
+ static const uint32_t kmask1 = 0x3f3f3f3f;
640
+ static const uint32_t kmask2 = 0x0f0f0f0f;
641
+ static const uint32_t kmask3 = 0x03030303;
642
+
643
+ assert (n % qk == 0);
644
+ assert (nr % 4 == 0);
645
+ assert (nc % ncols_interleaved == 0);
646
+
647
+ UNUSED(s);
648
+ UNUSED(bs);
649
+ UNUSED(vx);
650
+ UNUSED(vy);
651
+ UNUSED(nr);
652
+ UNUSED(nc);
653
+ UNUSED(nb);
654
+ UNUSED(ncols_interleaved);
655
+ UNUSED(blocklen);
656
+
657
+ float sumf[4][8];
658
+ float sum_minf[4][8];
659
+ uint32_t utmp[32];
660
+ int sumi1;
661
+ int sumi2;
662
+ int sumi;
663
+
664
+ for (int y = 0; y < nr / 4; y++) {
665
+ const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
666
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
667
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
668
+ for (int m = 0; m < 4; m++) {
669
+ for (int j = 0; j < ncols_interleaved; j++) {
670
+ sumf[m][j] = 0.0;
671
+ sum_minf[m][j] = 0.0;
672
+ }
673
+ }
674
+ for (int l = 0; l < nb; l++) {
675
+ for (int sb = 0; sb < 8; sb++) {
676
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
677
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
678
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
679
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
680
+ utmp[sb * 4 + 2] = uaux_0;
681
+ utmp[sb * 4 + 0] &= kmask1;
682
+ }
683
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
684
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
685
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
686
+ for (int m = 0; m < 4; m++) {
687
+ for (int j = 0; j < ncols_interleaved; j++) {
688
+ sumi1 = 0;
689
+ sumi2 = 0;
690
+ sumi = 0;
691
+ for (int i = 0; i < blocklen; ++i) {
692
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
693
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
694
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
695
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
696
+ sumi1 = sumi1 * scales_0[j];
697
+ sumi2 = sumi2 * scales_1[j];
698
+ sumi += sumi1 + sumi2;
699
+ }
700
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
701
+ }
702
+ }
703
+ }
704
+ for (int sb = 0; sb < 8; sb++) {
705
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
706
+ for(int m = 0; m < 4; m++) {
707
+ const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
708
+ for(int j = 0; j < ncols_interleaved; j++) {
709
+ sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
710
+ }
711
+ }
712
+ }
713
+ }
714
+ for (int m = 0; m < 4; m++) {
715
+ for (int j = 0; j < ncols_interleaved; j++) {
716
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
717
+ }
718
+ }
719
+ }
720
+ }
721
+ }
722
+ GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_K_8x8_q8_K)
723
+
724
+ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
725
+ const int qk = QK8_0;
726
+ const int nb = n / qk;
727
+ const int ncols_interleaved = 4;
728
+ const int blocklen = 4;
729
+
730
+ assert (n % qk == 0);
731
+ assert (nr % 4 == 0);
732
+ assert (nc % ncols_interleaved == 0);
733
+
734
+ UNUSED(s);
735
+ UNUSED(bs);
736
+ UNUSED(vx);
737
+ UNUSED(vy);
738
+ UNUSED(nr);
739
+ UNUSED(nc);
740
+ UNUSED(nb);
741
+ UNUSED(ncols_interleaved);
742
+ UNUSED(blocklen);
743
+
744
+ {
745
+ float sumf[4][4];
746
+ int sumi;
747
+
748
+ for (int y = 0; y < nr / 4; y++) {
749
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
750
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
751
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
752
+ for (int m = 0; m < 4; m++) {
753
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
754
+ }
755
+ for (int l = 0; l < nb; l++) {
756
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
757
+ for (int m = 0; m < 4; m++) {
758
+ for (int j = 0; j < ncols_interleaved; j++) {
759
+ sumi = 0;
760
+ for (int i = 0; i < blocklen; ++i) {
761
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
762
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
763
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
764
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
765
+ }
766
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
767
+ }
768
+ }
769
+ }
770
+ }
771
+ for (int m = 0; m < 4; m++) {
772
+ for (int j = 0; j < ncols_interleaved; j++)
773
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
774
+ }
775
+ }
776
+ }
777
+ }
778
+ }
779
+ GGML_CPU_NATIVE_IMPL(ggml_gemm_iq4_nl_4x4_q8_0)
780
+
781
+ } // extern "C"
782
+
783
+ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
784
+ block_q4_0x4 out;
785
+
786
+ for (int i = 0; i < 4; i++) {
787
+ out.d[i] = in[i].d;
788
+ }
789
+
790
+ const int end = QK4_0 * 2 / blck_size_interleave;
791
+
792
+ if (blck_size_interleave == 8) {
793
+ const uint64_t xor_mask = 0x8888888888888888ULL;
794
+ for (int i = 0; i < end; ++i) {
795
+ int src_id = i % 4;
796
+ int src_offset = (i / 4) * blck_size_interleave;
797
+ int dst_offset = i * blck_size_interleave;
798
+
799
+ uint64_t elems;
800
+ // Using memcpy to avoid unaligned memory accesses
801
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
802
+ elems ^= xor_mask;
803
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
804
+ }
805
+ } else if (blck_size_interleave == 4) {
806
+ const uint32_t xor_mask = 0x88888888;
807
+ for (int i = 0; i < end; ++i) {
808
+ int src_id = i % 4;
809
+ int src_offset = (i / 4) * blck_size_interleave;
810
+ int dst_offset = i * blck_size_interleave;
811
+
812
+ uint32_t elems;
813
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
814
+ elems ^= xor_mask;
815
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
816
+ }
817
+ } else {
818
+ GGML_ASSERT(false);
819
+ }
820
+
821
+ return out;
822
+ }
823
+
824
+ // interleave 8 block_q4_0s in blocks of blck_size_interleave
825
+ // returns an interleaved block_q4_0x8
826
+ // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
827
+ // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
828
+ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
829
+ block_q4_0x8 out;
830
+
831
+ for (int i = 0; i < 8; i++) {
832
+ out.d[i] = in[i].d;
833
+ }
834
+
835
+ const int end = QK4_0 * 4 / blck_size_interleave;
836
+ const uint64_t xor_mask = 0x8888888888888888ULL;
837
+
838
+ for (int i = 0; i < end; ++i) {
839
+ int src_id = i % 8;
840
+ int src_offset = (i / 8) * blck_size_interleave;
841
+ int dst_offset = i * blck_size_interleave;
842
+
843
+ uint64_t elems;
844
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
845
+ elems ^= xor_mask;
846
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
847
+ }
848
+
849
+ return out;
850
+ }
851
+
852
+ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
853
+ block_q4_Kx8 out;
854
+ // Delta (scale) and dmin values of the eight Q4_K structures are copied into the interleaved output structure
855
+ for (int i = 0; i < 8; i++) {
856
+ out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
857
+ }
858
+
859
+ for (int i = 0; i < 8; i++) {
860
+ out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
861
+ }
862
+
863
+ const int end = QK_K * 4 / blck_size_interleave;
864
+
865
+ // Interleave Q4_K quants by taking 8 bytes at a time
866
+ for (int i = 0; i < end; ++i) {
867
+ int src_id = i % 8;
868
+ int src_offset = (i / 8) * blck_size_interleave;
869
+ int dst_offset = i * blck_size_interleave;
870
+
871
+ uint64_t elems;
872
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
873
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
874
+ }
875
+
876
+ // The logic below unpacks and rearranges the scale and min values of the Q4_K blocks.
877
+ // Each Q4_K structure packs its 8 scales and 8 mins into 12 bytes (6 bits per value).
878
+ // The output Q4_Kx8 structure holds 96 scale bytes.
879
+ // Each 12-byte group packs the scales and mins of the corresponding sub-block from the 8 source Q4_K structures.
880
+ // E.g. the first 12 bytes hold 8 scales and 8 mins, one for the first sub-block of each source Q4_K structure.
881
+ uint8_t s[8], m[8];
882
+
883
+ for (int i = 0; i < 4; i++) {
884
+ for (int j = 0; j < 8; j++) {
885
+ s[j] = in[j].scales[i] & 63;
886
+ m[j] = in[j].scales[i + 4] & 63;
887
+ }
888
+
889
+ out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
890
+ out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
891
+ out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
892
+ out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
893
+ out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
894
+ out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
895
+ out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
896
+ out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
897
+ out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
898
+ out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
899
+ out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
900
+ out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
901
+
902
+ }
903
+
904
+ for (int i = 0; i < 4; i++) {
905
+ for (int j = 0; j < 8; j++) {
906
+ s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
907
+ m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
908
+ }
909
+
910
+ out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
911
+ out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
912
+ out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
913
+ out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
914
+ out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
915
+ out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
916
+ out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
917
+ out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
918
+ out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
919
+ out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
920
+ out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
921
+ out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
922
+
923
+ }
924
+
925
+ return out;
926
+ }
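To make the 12-byte scale groups described above concrete, here is a standalone round-trip sketch for one group of eight 6-bit scales and eight 6-bit mins (made-up values; in make_block_q4_Kx8 the eight values of a group come from the same sub-block of eight different source Q4_K blocks):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint8_t s[8] = {  1, 17, 33, 63, 50,  9, 42,  5 };  // 6-bit scales
        const uint8_t m[8] = { 12,  0, 61, 30,  7, 44, 19, 58 };  // 6-bit mins
        uint8_t packed[12];

        for (int j = 0; j < 4; j++) {
            packed[j]     = (s[j] & 63) | ((s[j + 4] & 48) << 2);     // s[0..3] whole + top 2 bits of s[4..7]
            packed[j + 4] = (m[j] & 63) | ((m[j + 4] & 48) << 2);     // m[0..3] whole + top 2 bits of m[4..7]
            packed[j + 8] = (s[j + 4] & 15) | ((m[j + 4] & 15) << 4); // low 4 bits of s[4..7] and m[4..7]
        }

        for (int j = 0; j < 4; j++) {
            assert((packed[j]     & 63) == s[j]);
            assert((packed[j + 4] & 63) == m[j]);
            assert((((packed[j]     >> 6) << 4) | (packed[j + 8] & 15)) == s[j + 4]);
            assert((((packed[j + 4] >> 6) << 4) | (packed[j + 8] >> 4)) == m[j + 4]);
        }
        printf("6-bit scale/min round-trip ok\n");
        return 0;
    }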
927
+
928
+ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
929
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
930
+ GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
931
+ constexpr int nrows_interleaved = 4;
932
+
933
+ block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
934
+ const block_q4_0 * src = (const block_q4_0 *)data;
935
+ block_q4_0 dst_tmp[4];
936
+ int nrow = ggml_nrows(t);
937
+ int nblocks = t->ne[0] / QK4_0;
938
+
939
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
940
+
941
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
942
+ return -1;
943
+ }
944
+
945
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
946
+ for (int64_t x = 0; x < nblocks; x++) {
947
+ for (int i = 0; i < nrows_interleaved; i++) {
948
+ dst_tmp[i] = src[x + i * nblocks];
949
+ }
950
+ *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
951
+ }
952
+ src += nrows_interleaved * nblocks;
953
+ }
954
+ return 0;
955
+
956
+ GGML_UNUSED(data_size);
957
+ }
958
+ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
959
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
960
+ GGML_ASSERT(interleave_block == 8);
961
+ constexpr int nrows_interleaved = 8;
962
+
963
+ block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
964
+ const block_q4_K * src = (const block_q4_K*) data;
965
+ block_q4_K dst_tmp[8];
966
+ int nrow = ggml_nrows(t);
967
+ int nblocks = t->ne[0] / QK_K;
968
+
969
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
970
+
971
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
972
+ return -1;
973
+ }
974
+
975
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
976
+ for (int64_t x = 0; x < nblocks; x++) {
977
+ for (int i = 0; i < nrows_interleaved; i++ ) {
978
+ dst_tmp[i] = src[x + i * nblocks];
979
+ }
980
+ *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
981
+ }
982
+ src += nrows_interleaved * nblocks;
983
+ }
984
+ return 0;
985
+
986
+ GGML_UNUSED(data_size);
987
+ }
988
+
989
+ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
990
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
991
+ GGML_ASSERT(interleave_block == 8);
992
+ constexpr int nrows_interleaved = 8;
993
+
994
+ block_q4_0x8 * dst = (block_q4_0x8*)t->data;
995
+ const block_q4_0 * src = (const block_q4_0*) data;
996
+ block_q4_0 dst_tmp[8];
997
+ int nrow = ggml_nrows(t);
998
+ int nblocks = t->ne[0] / QK4_0;
999
+
1000
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
1001
+
1002
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1003
+ return -1;
1004
+ }
1005
+
1006
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
1007
+ for (int64_t x = 0; x < nblocks; x++) {
1008
+ for (int i = 0; i < nrows_interleaved; i++ ) {
1009
+ dst_tmp[i] = src[x + i * nblocks];
1010
+ }
1011
+ *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
1012
+ }
1013
+ src += nrows_interleaved * nblocks;
1014
+ }
1015
+ return 0;
1016
+
1017
+ GGML_UNUSED(data_size);
1018
+ }
1019
+
1020
+ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
1021
+ block_iq4_nlx4 out;
1022
+
1023
+ for (int i = 0; i < 4; i++) {
1024
+ out.d[i] = in[i].d;
1025
+ }
1026
+
1027
+ const int end = QK4_NL * 2 / blck_size_interleave;
1028
+
1029
+ // TODO: this branch seems wrong
1030
+ //if (blck_size_interleave == 8) {
1031
+ // for (int i = 0; i < end; ++i) {
1032
+ // int src_id = i % 4;
1033
+ // int src_offset = (i / 4) * blck_size_interleave;
1034
+ // int dst_offset = i * blck_size_interleave;
1035
+
1036
+ // // Using memcpy to avoid unaligned memory accesses
1037
+ // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
1038
+ // }
1039
+ //} else
1040
+ if (blck_size_interleave == 4) {
1041
+ for (int i = 0; i < end; ++i) {
1042
+ int src_id = i % 4;
1043
+ int src_offset = (i / 4) * blck_size_interleave;
1044
+ int dst_offset = i * blck_size_interleave;
1045
+
1046
+ memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
1047
+ }
1048
+ } else {
1049
+ GGML_ASSERT(false);
1050
+ }
1051
+
1052
+ return out;
1053
+ }
1054
+
1055
+ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1056
+ GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
1057
+ //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
1058
+ GGML_ASSERT(interleave_block == 4);
1059
+
1060
+ block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
1061
+ const block_iq4_nl * src = (const block_iq4_nl *)data;
1062
+ block_iq4_nl dst_tmp[4];
1063
+ int nrow = ggml_nrows(t);
1064
+ int nrows_interleaved = 4;
1065
+ int nblocks = t->ne[0] / QK4_0;
1066
+
1067
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
1068
+
1069
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1070
+ return -1;
1071
+ }
1072
+
1073
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
1074
+ for (int64_t x = 0; x < nblocks; x++) {
1075
+ for (int i = 0; i < nrows_interleaved; i++) {
1076
+ dst_tmp[i] = src[x + i * nblocks];
1077
+ }
1078
+ *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
1079
+ }
1080
+ src += nrows_interleaved * nblocks;
1081
+ }
1082
+ return 0;
1083
+
1084
+ GGML_UNUSED(data_size);
1085
+ }
1086
+
1087
+ namespace ggml::cpu::repack {
1088
+ // repack
1089
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
1090
+ int repack(struct ggml_tensor *, const void *, size_t);
1091
+
1092
+ // TODO: generalise.
1093
+ template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1094
+ return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
1095
+ }
1096
+
1097
+ template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1098
+ return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
1099
+ }
1100
+
1101
+ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
1102
+ return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
1103
+ }
1104
+
1105
+ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
1106
+ return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
1107
+ }
1108
+
1109
+ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1110
+ return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
1111
+ }
1112
+
1113
+ // TODO: needs to be revisited
1114
+ //template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1115
+ // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
1116
+ //}
1117
+
1118
+ // gemv
1119
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1120
+ void gemv(int, float *, size_t, const void *, const void *, int, int);
1121
+
1122
+ template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1123
+ ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1124
+ }
1125
+
1126
+ template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1127
+ ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1128
+ }
1129
+
1130
+ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1131
+ ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1132
+ }
1133
+
1134
+ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1135
+ ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1136
+ }
1137
+
1138
+ template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1139
+ ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1140
+ }
1141
+
1142
+ // gemm
1143
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1144
+ void gemm(int, float *, size_t, const void *, const void *, int, int);
1145
+
1146
+ template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1147
+ ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1148
+ }
1149
+
1150
+ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1151
+ ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1152
+ }
1153
+
1154
+ template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1155
+ ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1156
+ }
1157
+
1158
+ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1159
+ ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1160
+ }
1161
+
1162
+ template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1163
+ ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1164
+ }
1165
+
1166
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
1167
+ public:
1168
+ virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
1169
+ };
1170
+
1171
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
1172
+
1173
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
1174
+ // not really a GGML_TYPE_Q8_0, but it has the same size.
1175
+ switch (op->op) {
1176
+ case GGML_OP_MUL_MAT:
1177
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1178
+ return true;
1179
+ case GGML_OP_MUL_MAT_ID:
1180
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1181
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
1182
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
1183
+ return true;
1184
+ default:
1185
+ // GGML_ABORT("fatal error");
1186
+ break;
1187
+ }
1188
+ return false;
1189
+ }
1190
+
1191
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
1192
+ switch (op->op) {
1193
+ case GGML_OP_MUL_MAT:
1194
+ forward_mul_mat(params, op);
1195
+ return true;
1196
+ case GGML_OP_MUL_MAT_ID:
1197
+ forward_mul_mat_id(params, op);
1198
+ return true;
1199
+ default:
1200
+ // GGML_ABORT("fatal error");
1201
+ break;
1202
+ }
1203
+ return false;
1204
+ }
1205
+
1206
+ void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
1207
+ const ggml_tensor * src0 = op->src[0];
1208
+ const ggml_tensor * src1 = op->src[1];
1209
+ ggml_tensor * dst = op;
1210
+
1211
+ GGML_TENSOR_BINARY_OP_LOCALS
1212
+
1213
+ const int ith = params->ith;
1214
+ const int nth = params->nth;
1215
+
1216
+ GGML_ASSERT(ne0 == ne01);
1217
+ GGML_ASSERT(ne1 == ne11);
1218
+ GGML_ASSERT(ne2 == ne12);
1219
+ GGML_ASSERT(ne3 == ne13);
1220
+
1221
+ // dst cannot be transposed or permuted
1222
+ GGML_ASSERT(nb0 == sizeof(float));
1223
+ GGML_ASSERT(nb0 <= nb1);
1224
+ GGML_ASSERT(nb1 <= nb2);
1225
+ GGML_ASSERT(nb2 <= nb3);
1226
+
1227
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
1228
+
1229
+ GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
1230
+ // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
1231
+
1232
+ char * wdata = static_cast<char *>(params->wdata);
1233
+ const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
1234
+
1235
+ assert(params->wsize >= nbw1 * ne11);
1236
+
1237
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
1238
+
1239
+ int64_t i11_processed = 0;
1240
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
1241
+ ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
1242
+ }
1243
+
1244
+ i11_processed = ne11 - ne11 % 4;
1245
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
1246
+ from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
1247
+ }
1248
+
1249
+ ggml_barrier(params->threadpool);
1250
+
1251
+ const void * src1_wdata = params->wdata;
1252
+ const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
1253
+ int64_t src0_start = (ith * ne01) / nth;
1254
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
1255
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
1256
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
1257
+ if (src0_start >= src0_end) {
1258
+ return;
1259
+ }
1260
+
1261
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
1262
+ if (ne11 > 3) {
1263
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1264
+ (float *) ((char *) dst->data) + src0_start, ne01,
1265
+ (const char *) src0->data + src0_start * nb01,
1266
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
1267
+ }
1268
+ for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
1269
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1270
+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
1271
+ (const char *) src0->data + src0_start * nb01,
1272
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
1273
+ src0_end - src0_start);
1274
+ }
1275
+ }
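The rounding of src0_start and src0_end to multiples of NB_COLS above keeps every thread on an interleaved-group boundary; the repack path only accepts tensors whose row count is a multiple of the interleave width, so the last range still ends exactly at ne01. A standalone sketch of the partitioning with made-up sizes:

    #include <cstdio>

    int main() {
        const int ne01 = 40, nth = 3, NB_COLS = 8;   // 40 interleaved rows, 3 threads, groups of 8
        for (int ith = 0; ith < nth; ith++) {
            int start = (ith * ne01) / nth;
            int end   = ((ith + 1) * ne01) / nth;
            // round both ends up to the next multiple of NB_COLS, as in forward_mul_mat
            start = (start % NB_COLS) ? start + NB_COLS - (start % NB_COLS) : start;
            end   = (end   % NB_COLS) ? end   + NB_COLS - (end   % NB_COLS) : end;
            if (start >= end) {
                printf("thread %d: nothing to do\n", ith);   // matches the early return above
                continue;
            }
            printf("thread %d: rows [%d, %d)\n", ith, start, end);  // [0,16) [16,32) [32,40)
        }
        return 0;
    }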
1276
+
1277
+ void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
1278
+ const ggml_tensor * src0 = op->src[0];
1279
+ const ggml_tensor * src1 = op->src[1];
1280
+ const ggml_tensor * ids = op->src[2];
1281
+ ggml_tensor * dst = op;
1282
+
1283
+ GGML_TENSOR_BINARY_OP_LOCALS
1284
+
1285
+ const int ith = params->ith;
1286
+ const int nth = params->nth;
1287
+
1288
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
1289
+
1290
+ // we don't support permuted src0 or src1
1291
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
1292
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
1293
+
1294
+ // dst cannot be transposed or permuted
1295
+ GGML_ASSERT(nb0 == sizeof(float));
1296
+ GGML_ASSERT(nb0 <= nb1);
1297
+ GGML_ASSERT(nb1 <= nb2);
1298
+ GGML_ASSERT(nb2 <= nb3);
1299
+
1300
+ GGML_ASSERT(ne03 == 1);
1301
+ GGML_ASSERT(ne13 == 1);
1302
+ GGML_ASSERT(ne3 == 1);
1303
+
1304
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
1305
+
1306
+ // row groups
1307
+ const int n_ids = ids->ne[0]; // n_expert_used
1308
+ const int n_as = ne02; // n_expert
1309
+
1310
+ const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
1311
+ const size_t nbw2 = nbw1*ne11;
1312
+ const size_t nbw3 = nbw2*ne12;
1313
+
1314
+ struct mmid_row_mapping {
1315
+ int32_t i1;
1316
+ int32_t i2;
1317
+ };
1318
+
1319
+ GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
1320
+ n_as * ne12 * sizeof(mmid_row_mapping)));
1321
+
1322
+ auto * wdata = (char *) params->wdata;
1323
+ auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
1324
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1325
+
1326
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1327
+
1328
+ // src1: float32 => param type
1329
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
1330
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
1331
+ from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
1332
+ (void *) (wdata + i12 * nbw2 + i11 * nbw1),
1333
+ ne10);
1334
+ }
1335
+ }
1336
+
1337
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
1338
+
1339
+ if (ith == 0) {
1340
+ // initialize matrix_row_counts
1341
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
1342
+
1343
+ // group rows by src0 matrix
1344
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
1345
+ for (int32_t id = 0; id < n_ids; ++id) {
1346
+ const int32_t i02 =
1347
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
1348
+
1349
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
1350
+
1351
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
1352
+ matrix_row_counts[i02] += 1;
1353
+ }
1354
+ }
1355
+ }
1356
+
1357
+ ggml_barrier(params->threadpool);
1358
+
1359
+ // compute each matrix multiplication in sequence
1360
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
1361
+ const int64_t cne1 = matrix_row_counts[cur_a];
1362
+
1363
+ if (cne1 == 0) {
1364
+ continue;
1365
+ }
1366
+
1367
+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
1368
+
1369
+ //const int64_t nr0 = ne01; // src0 rows
1370
+ const int64_t nr1 = cne1; // src1 rows
1371
+
1372
+ int64_t src0_cur_start = (ith * ne01) / nth;
1373
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
1374
+
1375
+ src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
1376
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
1377
+
1378
+ if (src0_cur_start >= src0_cur_end) {
1379
+ return;
1380
+ }
1381
+
1382
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
1383
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
1384
+
1385
+ const int id = row_mapping.i1; // selected expert index
1386
+
1387
+ const int64_t i11 = id % ne11;
1388
+ const int64_t i12 = row_mapping.i2; // row index in src1
1389
+
1390
+ const int64_t i1 = id; // selected expert index
1391
+ const int64_t i2 = i12; // row
1392
+
1393
+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
1394
+
1395
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1396
+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
1397
+ src0_cur + src0_cur_start * nb01,
1398
+ src1_col, 1, src0_cur_end - src0_cur_start);
1399
+ }
1400
+ }
1401
+ #undef MMID_MATRIX_ROW
1402
+ }
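forward_mul_mat_id first buckets every (expert slot, token) pair by the expert id read from ids, so that each expert's rows can then be fed to gemv back to back. A standalone sketch of that grouping step with toy sizes (std::vector stands in for the wdata-backed matrix_rows / matrix_row_counts arrays used above):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_expert = 4, n_tokens = 3, n_expert_used = 2;
        const int ids[n_tokens][n_expert_used] = { {2, 0}, {1, 2}, {2, 3} };  // expert picked per (token, slot)

        struct row_mapping { int slot, token; };
        std::vector<std::vector<row_mapping>> rows(n_expert);
        for (int token = 0; token < n_tokens; token++) {
            for (int slot = 0; slot < n_expert_used; slot++) {
                rows[ids[token][slot]].push_back({slot, token});  // group rows by src0 matrix (expert)
            }
        }
        for (int e = 0; e < n_expert; e++) {
            printf("expert %d handles %zu row(s)\n", e, rows[e].size());  // 1, 1, 3, 1
        }
        return 0;
    }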
1403
+
1404
+ int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
1405
+ GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
1406
+ (int) NB_COLS, (int) INTER_SIZE);
1407
+ return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
1408
+ }
1409
+ };
1410
+
1411
+ // instance for Q4
1412
+ static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
1413
+ static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
1414
+ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
1415
+ static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
1416
+
1417
+ // instance for IQ4
1418
+ static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
1419
+
1420
+ } // namespace ggml::cpu::repack
1421
+
1422
+ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
1423
+ if (cur->type == GGML_TYPE_Q4_0) {
1424
+ if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
1425
+ if (cur->ne[1] % 8 == 0) {
1426
+ return &ggml::cpu::repack::q4_0_8x8_q8_0;
1427
+ }
1428
+ }
1429
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
1430
+ if (cur->ne[1] % 4 == 0) {
1431
+ return &ggml::cpu::repack::q4_0_4x8_q8_0;
1432
+ }
1433
+ }
1434
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
1435
+ if (cur->ne[1] % 4 == 0) {
1436
+ return &ggml::cpu::repack::q4_0_4x4_q8_0;
1437
+ }
1438
+ }
1439
+ } else if (cur->type == GGML_TYPE_Q4_K) {
1440
+ if (ggml_cpu_has_avx2()) {
1441
+ if (cur->ne[1] % 8 == 0) {
1442
+ return &ggml::cpu::repack::q4_K_8x8_q8_K;
1443
+ }
1444
+ }
1445
+ } else if (cur->type == GGML_TYPE_IQ4_NL) {
1446
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
1447
+ if (cur->ne[1] % 4 == 0) {
1448
+ return &ggml::cpu::repack::iq4_nl_4x4_q8_0;
1449
+ }
1450
+ }
1451
+ }
1452
+
1453
+ return nullptr;
1454
+ }
1455
+
1456
+ static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
1457
+ tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type(tensor));
1458
+
1459
+ GGML_UNUSED(buffer);
1460
+ return GGML_STATUS_SUCCESS;
1461
+ }
1462
+
1463
+ static void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
1464
+ const void * data, size_t offset, size_t size) {
1465
+ GGML_ASSERT(offset == 0);
1466
+ GGML_ASSERT(size == ggml_nbytes(tensor));
1467
+
1468
+ auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra;
1469
+ auto OK = tensor_traits->repack(tensor, data, size);
1470
+
1471
+ GGML_ASSERT(OK == 0);
1472
+ GGML_UNUSED(buffer);
1473
+ }
1474
+
1475
+ static const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1476
+ return "CPU_REPACK";
1477
+
1478
+ GGML_UNUSED(buft);
1479
+ }
1480
+
1481
+ static ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1482
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1483
+
1484
+ if (buffer == nullptr) {
1485
+ return nullptr;
1486
+ }
1487
+
1488
+ buffer->buft = buft;
1489
+ buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor;
1490
+ buffer->iface.set_tensor = ggml_backend_cpu_repack_buffer_set_tensor;
1491
+ buffer->iface.get_tensor = nullptr;
1492
+ buffer->iface.cpy_tensor = nullptr;
1493
+ return buffer;
1494
+ }
1495
+
1496
+ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1497
+ return TENSOR_ALIGNMENT;
1498
+
1499
+ GGML_UNUSED(buft);
1500
+ }
1501
+
1502
+ namespace ggml::cpu::repack {
1503
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
1504
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1505
+ if ( op->op == GGML_OP_MUL_MAT &&
1506
+ op->src[0]->buffer &&
1507
+ (ggml_n_dims(op->src[0]) == 2) &&
1508
+ op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() &&
1509
+ ggml_repack_get_optimal_repack_type(op->src[0])
1510
+ ) {
1511
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
1512
+ return false;
1513
+ }
1514
+ if (op->src[1]->type == GGML_TYPE_F32) {
1515
+ return true;
1516
+ }
1517
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
1518
+ // return true;
1519
+ //}
1520
+ // may be possible if Q8_0 packed...
1521
+ } else if (op->op == GGML_OP_MUL_MAT_ID
1522
+ && op->src[0]->buffer
1523
+ && (ggml_n_dims(op->src[0]) == 3)
1524
+ && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
1525
+ && ggml_repack_get_optimal_repack_type(op->src[0])
1526
+ ) {
1527
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
1528
+ return false;
1529
+ }
1530
+ if (op->src[1]->type == GGML_TYPE_F32) {
1531
+ return true;
1532
+ }
1533
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
1534
+ // return true;
1535
+ //}
1536
+ }
1537
+ return false;
1538
+ }
1539
+
1540
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
1541
+ if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
1542
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
1543
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
1544
+ }
1545
+ }
1546
+ return nullptr;
1547
+ }
1548
+ };
1549
+ } // namespace ggml::cpu::repack
1550
+
1551
+ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
1552
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = {
1553
+ /* .iface = */ {
1554
+ /* .get_name = */ ggml_backend_cpu_repack_buffer_type_get_name,
1555
+ /* .alloc_buffer = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
1556
+ /* .get_alignment = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
1557
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
1558
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
1559
+ /* .is_host = */ nullptr,
1560
+ },
1561
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
1562
+ /* .context = */ new ggml::cpu::repack::extra_buffer_type(),
1563
+ };
1564
+
1565
+ return &ggml_backend_cpu_buffer_type_repack;
1566
+ }
ggml/src/ggml-cpu/repack.h ADDED
@@ -0,0 +1,119 @@
1
+ #pragma once
2
+
3
+ #define GGML_COMMON_DECL_CPP
4
+ #include "ggml-common.h"
5
+
6
+ #include "traits.h"
7
+ #include "ggml.h"
8
+
9
+ // GGML internal header
10
+
11
+ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
12
+
13
+ template <int K> constexpr int QK_0() {
14
+ if constexpr (K == 4) {
15
+ return QK4_0;
16
+ }
17
+ if constexpr (K == 8) {
18
+ return QK8_0;
19
+ }
20
+ return -1;
21
+ }
22
+
23
+ template <int K, int N> struct block {
24
+ ggml_half d[N]; // deltas for N qK_0 blocks
25
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
26
+ };
27
+
28
+ // compile-time size checks
29
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
30
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
31
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
32
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
33
+
34
+ using block_q4_0x4 = block<4, 4>;
35
+ using block_q4_0x8 = block<4, 8>;
36
+ using block_q8_0x4 = block<8, 4>;
37
+ using block_q8_0x8 = block<8, 8>;
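A self-contained re-derivation of the sizes asserted above, using QK4_0 = QK8_0 = 32 and a 2-byte stand-in for ggml_half (constants restated from ggml-common.h, not new definitions):

    #include <cstdint>

    enum { QK4_0 = 32, QK8_0 = 32 };

    template <int K, int N> struct blk {
        uint16_t d[N];                                        // stand-in for ggml_half
        int8_t   qs[((K == 4 ? QK4_0 : QK8_0) * N * K) / 8];  // N interleaved qK_0 blocks
    };

    static_assert(sizeof(blk<4, 4>) ==  72, "4 deltas (8 B)  + 64 B of 4-bit quants");
    static_assert(sizeof(blk<4, 8>) == 144, "8 deltas (16 B) + 128 B of 4-bit quants");
    static_assert(sizeof(blk<8, 4>) == 136, "4 deltas (8 B)  + 128 B of 8-bit quants");
    static_assert(sizeof(blk<8, 8>) == 272, "8 deltas (16 B) + 256 B of 8-bit quants");

    int main() { return 0; }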
38
+
39
+ struct block_q4_Kx8 {
40
+ ggml_half d[8]; // super-block scale for quantized scales
41
+ ggml_half dmin[8]; // super-block scale for quantized mins
42
+ uint8_t scales[96]; // scales and mins, quantized with 6 bits
43
+ uint8_t qs[1024]; // 4-bit quants
44
+ };
45
+
46
+ static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
47
+
48
+ struct block_q8_Kx4 {
49
+ float d[4]; // delta
50
+ int8_t qs[QK_K * 4]; // quants
51
+ int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
52
+ };
53
+
54
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
55
+
56
+ struct block_iq4_nlx4 {
57
+ ggml_half d[4]; // deltas for 4 iq4_nl blocks
58
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
59
+ };
60
+
61
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
62
+
63
+ #if defined(__cplusplus)
64
+ extern "C" {
65
+ #endif
66
+
67
+ // Workaround for clang:
68
+ // clang++ complains: ``error: call to 'ggml_gemm_q4_0_4x4_q8_0' is ambiguous''
69
+ // repro: https://godbolt.org/z/oKdeWKonM (ICE), https://godbolt.org/z/1szq6P36v (ambiguous call)
70
+ #if defined(GGML_CPU_CLANG_WORKAROUND) || !(defined(__GNUC__) && defined(__clang__)) || defined(__HIPCC__)
71
+ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
72
+ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
73
+ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
74
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
75
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
76
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
77
+ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
78
+ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
79
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
80
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
81
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
82
+ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
83
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
84
+ #endif // GGML_CPU_CLANG_WORKAROUND || !(__GNUC__ && __clang__) || __HIPCC__
85
+
86
+ // Native implementations
87
+ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
88
+ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
89
+ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
90
+ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
91
+ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
92
+ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
93
+ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
94
+ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
95
+ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
96
+ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
97
+ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
98
+ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
99
+ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
100
+
101
+ #if defined(GGML_CPU_GENERIC)
102
+ #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
103
+ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
104
+ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
105
+ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
106
+ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
107
+ #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
108
+ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
109
+ #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
110
+ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
111
+ #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
112
+ #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
113
+ #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
114
+ #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
115
+ #endif
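On a GGML_CPU_GENERIC build, the defines above make the *_generic definitions provide the public symbols directly. On architecture builds this commit instead relies on weak aliases (the GGML_CPU_NATIVE_IMPL marker used in repack.cpp; its definition is not shown in this diff), so an arch-specific object file can override the public name with a strong definition. A simplified illustration of that linking technique with hypothetical names, GCC/Clang on ELF only (Apple and MSVC targets need different handling):

    extern "C" {

    // portable reference implementation, always compiled
    void my_gemv_generic(int n) {
        (void) n;  // scalar fallback body would live here
    }

    // weak alias: my_gemv resolves to my_gemv_generic unless another object file
    // in the link provides a strong definition of my_gemv
    void my_gemv(int n) __attribute__((weak, alias("my_gemv_generic")));

    }  // extern "C"

    int main() {
        my_gemv(128);  // no strong override in this binary, so this calls my_gemv_generic
        return 0;
    }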
116
+
117
+ #if defined(__cplusplus)
118
+ } // extern "C"
119
+ #endif
ggml/src/ggml-cpu/traits.cpp ADDED
@@ -0,0 +1,36 @@
+ #include "traits.h"
+
+ #include "ggml-backend-impl.h"
+ #include "ggml-backend.h"
+
+ namespace ggml::cpu {
+ tensor_traits::~tensor_traits() {}
+
+ extra_buffer_type::~extra_buffer_type() {}
+ } // namespace ggml::cpu
+
+ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
+     for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+         if (extra && extra->context) {
+             auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
+             auto tensor_traits = buf_extra->get_tensor_traits(op);
+             if (tensor_traits && tensor_traits->compute_forward(params, op)) {
+                 return true;
+             }
+         }
+     }
+     return false;
+ }
+
+ bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
+     for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+         if (extra && extra->context) {
+             auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
+             auto tensor_traits = buf_extra->get_tensor_traits(op);
+             if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {
+                 return true;
+             }
+         }
+     }
+     return false;
+ }
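The two helpers above are thin dispatchers: they iterate every registered extra buffer type and hand the op to the first tensor_traits that claims it, returning false when nothing does. A hypothetical call-site sketch of how the CPU backend's forward pass can consult the hook before its built-in kernels (compute_builtin is a made-up stand-in for the regular dispatch, not a real ggml symbol):

    #include "traits.h"   // declares ggml_cpu_extra_compute_forward()

    static void compute_builtin(struct ggml_compute_params * params, struct ggml_tensor * op) {
        (void) params; (void) op;   // placeholder for the normal per-op kernels
    }

    static void compute_one_op(struct ggml_compute_params * params, struct ggml_tensor * op) {
        // give extra buffer types (e.g. repacked weights) first refusal on the op
        if (ggml_cpu_extra_compute_forward(params, op)) {
            return;
        }
        compute_builtin(params, op);
    }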
ggml/src/ggml-cpu/traits.h ADDED
@@ -0,0 +1,38 @@
+ #pragma once
+ #include "ggml-backend-impl.h"
+ #include "ggml-cpu-impl.h"
+ #include "ggml.h"
+
+ #ifdef __cplusplus
+ # include <vector>
+ extern "C" {
+ #endif
+
+ // return true if the op is handled by an extra "accelerator" buffer type
+ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op);
+ bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size);
+
+ #ifdef __cplusplus
+ }
+
+ namespace ggml::cpu {
+ // instances of this class are registered in tensor->extra
+ class tensor_traits {
+   public:
+     virtual ~tensor_traits();
+     virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size) = 0;
+     virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0;
+ };
+
+ class extra_buffer_type {
+   public:
+     virtual ~extra_buffer_type();
+     virtual bool supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0;
+     virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op) = 0;
+ };
+ } // namespace ggml::cpu
+
+ // implemented in ggml-cpu.cpp
+ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
+
+ #endif
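A minimal sketch of how a new extra buffer type could implement these interfaces; everything in the ggml::cpu::example namespace below is made up for illustration and is not part of the commit:

    #include "traits.h"

    namespace ggml::cpu::example {

    class my_tensor_traits : public tensor_traits {
      public:
        bool work_size(int /*n_threads*/, const struct ggml_tensor * /*op*/, size_t & size) override {
            size = 0;      // this sketch needs no extra work buffer
            return true;
        }
        bool compute_forward(struct ggml_compute_params * /*params*/, struct ggml_tensor * /*op*/) override {
            return false;  // returning false lets the regular CPU path run
        }
    };

    class my_extra_buffer_type : public extra_buffer_type {
      public:
        bool supports_op(ggml_backend_dev_t /*dev*/, const struct ggml_tensor * /*op*/) override {
            return false;  // this sketch advertises no ops
        }
        tensor_traits * get_tensor_traits(const struct ggml_tensor * /*op*/) override {
            // a real implementation would return the traits object stored in the
            // weight tensor's extra field when that tensor lives in its buffer
            return nullptr;
        }
    };

    } // namespace ggml::cpu::example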
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  }

- // TODO: move to ggml-common.h
- static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

  static __device__ __forceinline__ float get_alibi_slope(
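This hunk, together with the ggml-quants.c and ggml-sycl/common.hpp hunks below, drops per-backend copies of the same 16-entry IQ4_NL codebook (the removed TODO already pointed at ggml-common.h as its intended home). For context, a self-contained sketch of how such a codebook is typically used: each 4-bit IQ4_NL index selects a codebook value that is scaled by the block scale d. The values are copied from the removed line; the names are local to this sketch:

    #include <cstdint>

    // local copy of the 16-entry non-linear codebook from the removed line
    static const int8_t kvalues_iq4nl_copy[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
    };

    // dequantize one IQ4_NL value: look up the 4-bit index, scale by the block scale
    static inline float iq4nl_dequant_one(uint8_t idx4, float d) {
        return d * (float) kvalues_iq4nl_copy[idx4 & 0x0F];
    }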
ggml/src/ggml-quants.c CHANGED
@@ -2425,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_REST
  }
  }

- static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
  void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
      assert(k % QK4_NL == 0);
      const int64_t nb = k / QK4_NL;
ggml/src/ggml-sycl/common.hpp CHANGED
@@ -149,8 +149,6 @@ typedef sycl::float2 dfloat2;

  #define MMVQ_MAX_BATCH_SIZE 8

- static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
  static int g_all_sycl_device_count = -1;
  static bool g_ggml_backend_sycl_buffer_type_initialized = false;
