Spaces:
Running
Running
ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (llama/11227)
Browse files* Add SVE support for q4_K_q8_K
* Update ggml/src/ggml-cpu/ggml-cpu-quants.c
change to use K_SCALE_SIZE
Co-authored-by: Georgi Gerganov <[email protected]>
---------
Co-authored-by: Georgi Gerganov <[email protected]>
ggml/src/ggml-cpu/ggml-cpu-quants.c
CHANGED
|
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
| 5573 |
|
| 5574 |
uint32_t utmp[4];
|
| 5575 |
|
| 5576 |
-
#ifdef
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5577 |
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
| 5578 |
const int32x4_t mzero = vdupq_n_s32(0);
|
| 5579 |
|
|
|
|
| 5573 |
|
| 5574 |
uint32_t utmp[4];
|
| 5575 |
|
| 5576 |
+
#ifdef __ARM_FEATURE_SVE
|
| 5577 |
+
float sumf = 0;
|
| 5578 |
+
for (int i = 0; i < nb; ++i) {
|
| 5579 |
+
|
| 5580 |
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
| 5581 |
+
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
| 5582 |
+
|
| 5583 |
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
| 5584 |
+
|
| 5585 |
+
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
|
| 5586 |
+
|
| 5587 |
+
uint32x2_t mins8 = { 0 };
|
| 5588 |
+
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
|
| 5589 |
+
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
|
| 5590 |
+
|
| 5591 |
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
| 5592 |
+
utmp[0] &= kmask1;
|
| 5593 |
+
|
| 5594 |
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
| 5595 |
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
| 5596 |
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
| 5597 |
+
sumf -= dmin * vaddvq_s32(prod);
|
| 5598 |
+
|
| 5599 |
+
const uint8_t * scales = (const uint8_t *)utmp;
|
| 5600 |
+
|
| 5601 |
+
const uint8_t * restrict q4 = x[i].qs;
|
| 5602 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 5603 |
+
|
| 5604 |
+
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
| 5605 |
+
const svuint8_t m4b = svdup_n_u8(0xf);
|
| 5606 |
+
const svint32_t mzero = svdup_n_s32(0);
|
| 5607 |
+
svint32_t sumi1 = svdup_n_s32(0);
|
| 5608 |
+
svint32_t sumi1_1 = svdup_n_s32(0);
|
| 5609 |
+
svint32_t sumi1_2 = svdup_n_s32(0);
|
| 5610 |
+
svint32_t sumi2 = svdup_n_s32(0);
|
| 5611 |
+
svint32_t sumi2_1 = svdup_n_s32(0);
|
| 5612 |
+
svint32_t sumi2_2 = svdup_n_s32(0);
|
| 5613 |
+
switch (vector_length) {
|
| 5614 |
+
case 128:
|
| 5615 |
+
{
|
| 5616 |
+
for (int j = 0; j < QK_K/64; ++j) {
|
| 5617 |
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
|
| 5618 |
+
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
| 5619 |
+
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
| 5620 |
+
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
|
| 5621 |
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
| 5622 |
+
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
| 5623 |
+
|
| 5624 |
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
|
| 5625 |
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
| 5626 |
+
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
| 5627 |
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
|
| 5628 |
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
| 5629 |
+
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
| 5630 |
+
q4 += 32;
|
| 5631 |
+
}
|
| 5632 |
+
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
|
| 5633 |
+
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
|
| 5634 |
+
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
|
| 5635 |
+
} break;
|
| 5636 |
+
case 256:
|
| 5637 |
+
case 512:
|
| 5638 |
+
{
|
| 5639 |
+
for (int j = 0; j < QK_K/64; ++j) {
|
| 5640 |
+
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
|
| 5641 |
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
|
| 5642 |
+
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
| 5643 |
+
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
| 5644 |
+
|
| 5645 |
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
|
| 5646 |
+
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
| 5647 |
+
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
| 5648 |
+
}
|
| 5649 |
+
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
|
| 5650 |
+
} break;
|
| 5651 |
+
default:
|
| 5652 |
+
assert(false && "Unsupported vector length");
|
| 5653 |
+
break;
|
| 5654 |
+
}
|
| 5655 |
+
}
|
| 5656 |
+
*s = sumf;
|
| 5657 |
+
#elif __ARM_NEON
|
| 5658 |
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
| 5659 |
const int32x4_t mzero = vdupq_n_s32(0);
|
| 5660 |
|