fj-y-saito and ggerganov committed
Commit bf3dc93 · 1 Parent(s): 03ab36f

ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (llama/11227)


* Add SVE support for q4_K_q8_K

* Update ggml/src/ggml-cpu/ggml-cpu-quants.c

change to use K_SCALE_SIZE

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
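
For orientation, the quantity that both the existing NEON path and the new SVE path compute per 256-element super-block can be written as a plain scalar loop. The sketch below is illustrative only and not part of the patch: it assumes QK_K = 256, that the packed 6-bit scales/mins have already been unpacked into scales[8]/mins[8], that q4 holds the 4-bit weights expanded to bytes in the order the kernel consumes them, that bsums[8] are the sums of each group of 32 q8 values (the result of the pairwise add of y[i].bsums), and that d/dmin are the combined factors y[i].d * GGML_FP16_TO_FP32(x[i].d) and y[i].d * GGML_FP16_TO_FP32(x[i].dmin); the helper name is hypothetical.

    #include <stdint.h>

    // Scalar reference for one q4_K x q8_K super-block (illustrative sketch).
    // scales[8]/mins[8]: unpacked 6-bit sub-block scales and mins
    // q4[256]:           4-bit weights expanded to bytes (0..15), kernel order
    // q8[256]:           int8 activations
    // bsums[8]:          sums of each group of 32 q8 values
    // d, dmin:           combined fp32 block factors (see above)
    static float q4_K_q8_K_superblock_ref(float d, float dmin,
                                          const uint8_t scales[8], const uint8_t mins[8],
                                          const uint8_t q4[256], const int8_t q8[256],
                                          const int16_t bsums[8]) {
        int32_t sumi = 0;
        for (int sb = 0; sb < 8; ++sb) {                 // 8 sub-blocks of 32 weights
            int32_t dot = 0;
            for (int k = 0; k < 32; ++k) {
                dot += (int32_t) q4[32*sb + k] * q8[32*sb + k];
            }
            sumi += (int32_t) scales[sb] * dot;          // what sumi1/sumi2 accumulate
        }
        int32_t mins_dot = 0;
        for (int sb = 0; sb < 8; ++sb) {
            mins_dot += (int32_t) mins[sb] * bsums[sb];  // the q8sums * mins correction
        }
        return d * sumi - dmin * mins_dot;               // sumf += d*...; sumf -= dmin*...
    }

The SVE path below reproduces exactly this computation, with the inner 32-wide dot product replaced by svdot_s32 over int8 lanes and the per-sub-block scale folded in via svmla_n_s32_x.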

Files changed (1)
  1. ggml/src/ggml-cpu/ggml-cpu-quants.c +82 -1
ggml/src/ggml-cpu/ggml-cpu-quants.c CHANGED
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
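
One detail worth noting about the 256-/512-bit case: it uses fixed-count predicates (svptrue_pat_b8(SV_VL32) and svptrue_pat_b32(SV_VL8)) so loads and reductions always cover exactly 32 bytes / 8 words, and because svld1 zeroes inactive lanes, the extra lanes of a wider 512-bit vector contribute nothing to the dot product; the same code therefore runs unchanged on 256-bit and 512-bit implementations. Below is a minimal standalone sketch of that pattern, assuming an SVE-capable toolchain (e.g. -march=armv8.2-a+sve); the helper and its arguments are hypothetical and not part of the patch.

    #include <arm_sve.h>
    #include <stdint.h>

    // Dot 32 packed q4 bytes (64 nibbles) against 64 int8 values, scaling the
    // low-nibble and high-nibble halves separately -- the same shape as one
    // j-iteration of the 256/512-bit loop above (illustrative sketch only).
    static int32_t q4_nibble_dot_sve(const uint8_t * q4,   /* 32 bytes */
                                     const int8_t  * q8,   /* 64 bytes */
                                     uint8_t scale_lo, uint8_t scale_hi) {
        const svbool_t  pg8   = svptrue_pat_b8 (SV_VL32);  // exactly 32 byte lanes
        const svbool_t  pg32  = svptrue_pat_b32(SV_VL8);   // exactly 8 word lanes
        const svuint8_t m4b   = svdup_n_u8(0xf);
        const svint32_t mzero = svdup_n_s32(0);

        const svuint8_t q4bits = svld1_u8(pg8, q4);        // inactive lanes load as 0

        // Low nibbles pair with the first 32 q8 values.
        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(pg8, q4bits, m4b));
        svint8_t q8bytes = svld1_s8(pg8, q8);
        svint32_t acc = svmla_n_s32_x(pg32, mzero, svdot_s32(mzero, q4bytes, q8bytes), scale_lo);

        // High nibbles pair with the next 32 q8 values.
        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(pg8, q4bits, 4));
        q8bytes = svld1_s8(pg8, q8 + 32);
        acc = svmla_n_s32_x(pg32, acc, svdot_s32(mzero, q4bytes, q8bytes), scale_hi);

        return (int32_t) svaddv_s32(pg32, acc);            // reduce the 8 active lanes
    }

The 128-bit case cannot split a single vector this way, which is why the patch there uses full svptrue_b8()/svptrue_b32() predicates over 16-byte chunks and keeps two partial accumulators per scale (sumi1_1/sumi1_2, sumi2_1/sumi2_2) that are merged after the loop.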