xctan committed
Commit 4790d5d · 1 Parent(s): d86ba47

ggml : riscv: add xtheadvector support (llama/13720)

* ggml : riscv: add xtheadvector support

* ggml : clean up some macro usage
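For orientation, the compile-time dispatch this commit introduces in the RISC-V kernels is sketched below in plain C. Only the feature-test macros (__riscv_xtheadvector, __riscv_v) are taken from the diff; the function and its body are illustrative placeholders, not ggml code.

// Illustrative sketch of the new dispatch order (not ggml code).
// The T-Head XTheadVector path is tried first and is written with
// th.*-prefixed inline assembly in the diff below; the generic RVV
// path now keys off __riscv_v, which is defined whenever the V
// extension is enabled, instead of the narrower __riscv_v_intrinsic.
void example_kernel(void) {
#if defined(__riscv_xtheadvector)
    // th.* inline assembly path (vendor extension, RVV 0.7.1-style)
#elif defined(__riscv_v)
    // standard RVV path (intrinsics or v* inline assembly)
#else
    // scalar fallback
#endif
}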

ggml/CMakeLists.txt CHANGED
@@ -129,6 +129,7 @@ option(GGML_LASX "ggml: enable lasx" ON)
  option(GGML_LSX "ggml: enable lsx" ON)
  option(GGML_RVV "ggml: enable rvv" ON)
  option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
+ option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
  option(GGML_VXE "ggml: enable vxe" ON)

  option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -357,8 +357,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
  elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
  message(STATUS "RISC-V detected")
  if (GGML_RVV)
- if (GGML_RV_ZFH)
- list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d)
+ if (GGML_XTHEADVECTOR)
+ list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
+ elseif (GGML_RV_ZFH)
+ list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
  else()
  list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
  endif()
ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp CHANGED
@@ -1191,7 +1191,7 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
  }
  }
  return;
- #elif defined(__riscv_v_intrinsic)
+ #elif defined __riscv_v
  if (__riscv_vlenb() >= QK4_0) {
  const size_t vl = QK4_0;

@@ -3783,7 +3783,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
  }
  return;
  }
- #elif defined(__riscv_v_intrinsic)
+ #elif defined __riscv_v
  if (__riscv_vlenb() >= QK4_0) {
  const size_t vl = QK4_0;

ggml/src/ggml-cpu/ggml-cpu-impl.h CHANGED
@@ -320,21 +320,17 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)

  #ifdef __wasm_simd128__
  #include <wasm_simd128.h>
- #else
+ #endif
+
  #ifdef __POWER9_VECTOR__
  #include <altivec.h>
- #else
+ #endif
+
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <intrin.h>
- #else
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
- #if !defined(__riscv)
+ #elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
  #include <immintrin.h>
  #endif
- #endif
- #endif
- #endif
- #endif

  #ifdef __riscv_v_intrinsic
  #include <riscv_vector.h>
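The flattened include guards above make each SIMD header check independent. Not part of the commit: a small probe program (assumption: compiled with the same ARCH_FLAGS chosen in ggml/src/ggml-cpu/CMakeLists.txt) that reports which of the RISC-V feature-test macros used in these guards and in the kernels below a given toolchain actually predefines.

// Hypothetical helper, not from the repository: print which RISC-V
// feature-test macros are predefined for the selected -march string
// (e.g. rv64gc_xtheadvector vs. rv64gcv_zfhmin vs. rv64gcv).
#include <stdio.h>

int main(void) {
#if defined(__riscv)
    printf("__riscv (xlen = %d)\n", __riscv_xlen);
#endif
#if defined(__riscv_v)
    printf("__riscv_v = %d\n", __riscv_v);
#endif
#if defined(__riscv_v_intrinsic)
    printf("__riscv_v_intrinsic = %d\n", __riscv_v_intrinsic);
#endif
#if defined(__riscv_xtheadvector)
    printf("__riscv_xtheadvector defined\n");
#endif
#if defined(__riscv_zfhmin)
    printf("__riscv_zfhmin defined\n");
#endif
    return 0;
}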
ggml/src/ggml-cpu/ggml-cpu-quants.c CHANGED
@@ -883,7 +883,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
  #endif
  }
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)

  size_t vl = QK8_0;

@@ -1221,7 +1221,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
  #endif
  }
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)

  size_t vl = QK8_1;

@@ -2384,7 +2384,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
@@ -2774,7 +2774,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc) + summs;
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
@@ -3121,7 +3121,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl;
  size_t vlenb = __riscv_vlenb();

@@ -3460,7 +3460,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc) + summs;
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl;
  size_t vlenb = __riscv_vlenb();

@@ -3897,7 +3897,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(accum);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk;

  for (; ib < nb; ++ib) {
@@ -5100,14 +5100,111 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
+ uint8_t atmp[16];
+
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ uint8_t *patmp = atmp;
+ int vsums;
+ int tmp;
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl16], e8, m1\n\t"
+ "th.vmv.v.x v8, zero\n\t"
+ "th.vlb.v v1, (%[sc])\n\t"
+ "th.vand.vi v0, v1, 0xF\n\t"
+ "th.vsrl.vi v1, v1, 4\n\t"
+ "th.vsb.v v0, (%[scale])\n\t"
+ "th.vwaddu.vx v16, v1, zero\n\t"
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
+ "th.vlh.v v2, (%[bsums])\n\t"
+ "th.vwmul.vv v4, v16, v2\n\t"
+ "th.vsetvli zero, %[vl16], e32, m4\n\t"
+ "th.vredsum.vs v8, v4, v8\n\t"
+ "th.vmv.x.s %[vsums], v8"
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+ , [vl16] "r" (16)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf += dmin * vsums;
+ int isum = 0;

+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vlb.v v0, (%[q2])\n\t"
+ "th.vsrl.vi v2, v0, 2\n\t"
+ "th.vsrl.vi v4, v0, 4\n\t"
+ "th.vsrl.vi v6, v0, 6\n\t"
+ "th.vand.vi v0, v0, 0x3\n\t"
+ "th.vand.vi v2, v2, 0x3\n\t"
+ "th.vand.vi v4, v4, 0x3\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v8, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[tmp], 8\n\t"
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
+ "th.vlbu.v v12, (%[scale])\n\t"
+ "th.vmul.vv v10, v10, v12\n\t"
+ "th.vredsum.vs v0, v10, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+ , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q2 += 32; q8 += 128; patmp += 8;
+ }
+
+ sumf += dall * isum;
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ float sumf = 0;
+ uint8_t atmp[16];
+
+ const int vector_length = __riscv_vlenb() * 8;
  uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
- uint8_t atmp[16];

  switch (vector_length) {
  case 256:
@@ -6137,13 +6234,140 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

- uint32_t aux[3];
  uint32_t utmp[4];
+ float sumf = 0;

+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+
+ int8_t * scale = (int8_t *)utmp;
+ int tmp;
+ __asm__ __volatile__(
+ "li %[tmp], 12\n\t"
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
+ "th.vlb.v v0, (%[s6b])\n\t"
+ "th.vmv.v.v v2, v0\n\t"
+ "li %[tmp], 2\n\t"
+ "th.vsetvli zero, %[tmp], e64, m1\n\t"
+ "th.vmv.v.x v9, %[sh]\n\t"\
+ "th.vslidedown.vi v1, v0, 1\n\t"
+ "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+ "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vid.v v9\n\t"
+ "th.vmv.x.s %[tmp], v1\n\t"
+ "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+ "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+ "th.vsrl.vv v4, v1, v9\n\t"
+ "th.vsrl.vv v2, v0, v8\n\t"
+ "th.vand.vx v5, v4, %[kmask1]\n\t"
+ "th.vand.vx v3, v2, %[kmask2]\n\t"
+ "th.vsll.vi v6, v5, 4\n\t"
+ "th.vor.vv v7, v6, v3\n\t"
+ "li %[tmp], 16\n\t"
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
+ "th.vsub.vx v0, v7, %[c]\n\t"
+ "th.vsb.v v0, (%[scale])"
+ : [tmp] "=&r" (tmp)
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+
+ uint8_t m = 1;
+ int isum = 0;
+ for (int j = 0; j < QK_K; j += 128) {
+ __asm__ __volatile__(
+ // fixme: use v0p7 mask layout directly
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vlb.v v8, (%[q3])\n\t"
+ "th.vsrl.vi v10, v8, 2\n\t"
+ "th.vsrl.vi v12, v8, 4\n\t"
+ "th.vsrl.vi v14, v8, 6\n\t"
+ "th.vand.vi v8, v8, 3\n\t"
+ "th.vand.vi v10, v10, 3\n\t"
+ "th.vand.vi v12, v12, 3\n\t"
+ "th.vlb.v v2, (%[qh])\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v8, v8, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v10, v10, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v12, v12, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v14, v14, -4, v0.t\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v0, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "li %[tmp], 16\n\t"
+ "th.vsetvli zero, %[tmp], e16, m2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[tmp], 8\n\t"
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
+ "th.vlb.v v12, (%[scale])\n\t"
+ "th.vmul.vv v10, v10, v12\n\t"
+ "th.vredsum.vs v0, v10, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q3 += 32; q8 += 128; scale += 8;
+ }
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ sumf += d * isum;
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ uint32_t utmp[4];
- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
+ uint32_t aux[3];
+ const int vector_length = __riscv_vlenb() * 8;

  switch (vector_length) {
  case 256:
@@ -6331,7 +6555,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  "vslideup.vi v13, v14, 1\n\t"
  "vslideup.vi v10, v8, 2\n\t"
  "vslideup.vi v11, v13, 2\n\t"
- "vsetivli zero, 8, e32, m2\n\t"\
+ "vsetivli zero, 8, e32, m2\n\t"
  "vle8.v v15, (%[scale])\n\t"
  "vsext.vf4 v12, v15\n\t"
  "vmul.vv v10, v10, v12\n\t"
@@ -7180,14 +7404,130 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

  const uint8_t * scales = (const uint8_t*)&utmp[0];
  const uint8_t * mins = (const uint8_t*)&utmp[2];

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;

+ for (int i = 0; i < nb; ++i) {
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ int tmp, tmp2, sumi;
+ __asm__ __volatile__(
+ "li %[t1], 12\n\t"
+ "th.vsetvli zero, %[t1], e8, m1\n\t"
+ "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+ "li %[t1], 4\n\t"
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
+ "th.vslidedown.vi v2, v1, 2\n\t"
+ "th.vmv.v.v v3, v2\n\t"
+ "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+ "li %[t1], 2\n\t"
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
+ "th.vmv.v.i v4, 4\n\t"
+ "th.vand.vx v8, v1, %[kmask1]\n\t"
+ "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+ "th.vsrl.vi v6, v1, 6\n\t"
+ "th.vsrl.vv v7, v2, v5\n\t"
+ "th.vand.vx v0, v6, %[kmask3]\n\t"
+ "th.vand.vx v2, v7, %[kmask2]\n\t"
+ "th.vsll.vi v6, v0, 4\n\t"
+ "li %[t2], 8\n\t"
+ "addi %[t1], %[utmp], 4\n\t"
+ "th.vor.vv v1, v6, v2\n\t"
+ "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
+ "th.vssw.v v1, (%[t1]), %[t2]\n\t"
+ "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
+ "th.vlw.v v2, (%[bsums])\n\t"
+ "th.vsetvli zero, %[t2], e16, m1\n\t"
+ "th.vnsrl.vi v0, v2, 0\n\t"
+ "th.vnsrl.vi v1, v2, 16\n\t"
+ "th.vadd.vv v2, v0, v1\n\t"
+ "th.vlbu.v v4, (%[mins])\n\t"
+ "th.vwmul.vv v6, v4, v2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vsetvli zero, %[t2], e32, m2\n\t"
+ "th.vredsum.vs v0, v6, v0\n\t"
+ "th.vmv.x.s %[sumi], v0"
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf -= dmin * sumi;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ sumi = 0;
+ const uint8_t * scale = scales;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ int vl128 = 128, vl64 = 64, vl32 = 32;
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v8, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vlb.v v0, (%[q4])\n\t"
+ "th.vsrl.vi v4, v0, 4\n\t"
+ "th.vand.vi v0, v0, 0xF\n\t"
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vwmul.vv v28, v6, v14\n\t"
+ "th.vwmul.vv v20, v4, v10\n\t"
+ "th.vwmul.vv v24, v2, v12\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vlbu.v v1, (%[scale])\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vsetvli zero, %[vl32], e16, m4\n\t"
+ "th.vwredsum.vs v6, v24, v0\n\t"
+ "th.vwredsum.vs v7, v28, v0\n\t"
+ "th.vwredsum.vs v4, v16, v0\n\t"
+ "th.vwredsum.vs v5, v20, v0\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v6, v7, 1\n\t"
+ "th.vslideup.vi v4, v5, 1\n\t"
+ "th.vslideup.vi v4, v6, 2\n\t"
+ "th.vmul.vv v8, v4, v1\n\t"
+ "th.vredsum.vs v0, v8, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[sumi], %[sumi], %[tmp]"
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+
+ q4 += 64; q8 += 128; scale += 4;
+ }
+
+ sumf += d * sumi;
+
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ float sumf = 0;
+ const int vector_length = __riscv_vlenb() * 8;
+
  switch (vector_length) {
  case 256:
  for (int i = 0; i < nb; ++i) {
@@ -8074,7 +8414,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_v

  const uint8_t * scales = (const uint8_t*)&utmp[0];
  const uint8_t * mins = (const uint8_t*)&utmp[2];
@@ -9232,10 +9572,91 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }
  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ int sum_t = 0;
+ int t0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
+ "th.vlb.v v4, (%[qh])\n\t"
+ "th.vsll.vi v0, v4, 4\n\t"
+ "th.vsll.vi v2, v4, 2\n\t"
+ "th.vsrl.vi v6, v4, 2\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
+ "th.vlb.v v8, (%[q6])\n\t"
+ "th.vsrl.vi v12, v8, 4\n\t"
+ "th.vand.vi v8, v8, 0xF\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
+ "th.vand.vx v0, v0, %[mask]\n\t"
+ "th.vor.vv v8, v8, v0\n\t"
+ "th.vlb.v v0, (%[q8])\n\t"
+ "th.vsub.vx v8, v8, %[vl32]\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "li %[t0], 16\n\t"
+ "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[t0], 4\n\t"
+ "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[t0], 8\n\t"
+ "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
+ "th.vlb.v v4, (%[scale])\n\t"
+ "th.vmul.vv v2, v4, v10\n\t"
+ "th.vredsum.vs v0, v2, v0\n\t"
+ "th.vmv.x.s %[t0], v0\n\t"
+ "add %[sumi], %[sumi], %[t0]"
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ , [mask] "r" (0x30)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
+ }
+
+ sumf += d * sum_t;
+
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
+ const int vector_length = __riscv_vlenb() * 8;

  switch (vector_length) {
  case 256:
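The new __riscv_xtheadvector blocks above all share the same shape: set vl through a scalar register (XTheadVector has no vsetivli), load the quantized bytes with th.vlb.v, widen-multiply against the int8 activations, then collapse the partial sums with th.vwredsum.vs and extract them with th.vmv.x.s. A stripped-down, untested sketch of that shape for a single 16-element int8 dot product follows; it is not code from the commit, the function name is made up, and it assumes VLEN >= 128 so that the requested vl of 16 is actually granted.

// Illustrative only: minimal th.* (XTheadVector) int8 dot product
// mirroring the structure of the kernels above. Guarded so it only
// compiles when the toolchain targets the vendor extension.
#include <stdint.h>

#if defined(__riscv_xtheadvector)
static inline int32_t dot_i8x16_xtheadvector(const int8_t * a, const int8_t * b) {
    int32_t sum;
    __asm__ __volatile__(
        "th.vsetvli zero, %[vl16], e8, m1\n\t"
        "th.vlb.v v0, (%[a])\n\t"            // 16 signed bytes from a
        "th.vlb.v v2, (%[b])\n\t"            // 16 signed bytes from b
        "th.vwmul.vv v4, v0, v2\n\t"         // 16 int16 products in v4-v5
        "th.vsetvli zero, %[vl16], e16, m2\n\t"
        "th.vmv.v.x v8, zero\n\t"            // zero the accumulator
        "th.vwredsum.vs v8, v4, v8\n\t"      // int32 sum in v8[0]
        "th.vsetvli zero, %[vl1], e32, m1\n\t"
        "th.vmv.x.s %[sum], v8"
        : [sum] "=&r" (sum)
        : [a] "r" (a), [b] "r" (b), [vl16] "r" (16), [vl1] "r" (1)
        : "memory"
        , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
    );
    return sum;
}
#endif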
ggml/src/ggml-impl.h CHANGED
@@ -386,7 +386,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
  return r;
  }

- #elif defined(__riscv) && defined(GGML_RV_ZFH)
+ #elif defined(__riscv) && defined(__riscv_zfhmin)

  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
  float f;
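The body of ggml_compute_fp16_to_fp32 is truncated in the hunk above. As a hedged reference, a half-to-float conversion of the kind this __riscv_zfhmin guard protects can be written with the two Zfhmin instructions involved: fmv.h.x moves the raw 16-bit pattern into an FP register as a half, and fcvt.s.h widens it to single precision. The actual ggml implementation may differ in detail; the typedef below is a stand-in.

// Sketch only (assumes a Zfhmin-capable RISC-V target); not the
// verbatim ggml code. fp16_bits_t stands in for ggml_fp16_t.
#include <stdint.h>

#if defined(__riscv) && defined(__riscv_zfhmin)
typedef uint16_t fp16_bits_t;   // raw IEEE-754 binary16 bit pattern

static inline float fp16_to_fp32_zfhmin(fp16_bits_t h) {
    float f;
    __asm__(
        "fmv.h.x  %[f], %[h]\n\t"   // reinterpret the 16 bits as a half
        "fcvt.s.h %[f], %[f]"       // widen half -> single
        : [f] "=&f" (f)
        : [h] "r" (h));
    return f;
}
#endif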