Commit 459beb1 · committed by Prashant Vithule (vithulep)
Parent: dcf68db

ggml: aarch64: implement SVE kernels for q2_k_q8_k vector dot (llama/12064)

* Added SVE support for Q2_K quantized models
* Use 4-space indentation in the switch cases
* Removed comment lines
* Removed the loop; retained the curly braces for better readability of the code
* Removed the comment added for the q3_k_q8_k kernel
---------
Co-authored-by: vithulep <[email protected]>
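
Background for the dispatch below (an illustration, not part of the commit): SVE hardware may implement any register width from 128 to 2048 bits, and svcntb() reports that width in bytes at run time. The new kernel therefore switches on svcntb()*8, with one path for 128-bit registers and a shared path for 256- and 512-bit registers. A minimal standalone sketch of the pattern, assuming GCC with -march=armv8-a+sve:

    #include <arm_sve.h>
    #include <stdio.h>

    int main(void) {
        const int vector_length = (int) svcntb() * 8;   // SVE register width in bits
        switch (vector_length) {
            case 128:
                puts("128-bit path: 16 quantized bytes per vector op");
                break;
            case 256:
            case 512:
                puts("256/512-bit path: predicated down to 32 bytes per op");
                break;
            default:
                puts("unsupported SVE vector length");
                break;
        }
        return 0;
    }
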
ggml/src/ggml-cpu/ggml-cpu-quants.c
CHANGED
@@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3s = svdup_n_u8(0x3);
+    const svuint32_t m4s = svdup_n_u32(0xF);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+    svfloat32_t acc_sum = svdup_n_f32(0);
+    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
+
+    switch (vector_length) {
+        case 128:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
+                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
+
+                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
+                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
+                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
+                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
+
+                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
+
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
+
+                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
+
+                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
+
+                    //-------------------------------
+
+                    q2 += 32;
+                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
+
+                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
+
+                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
+                }
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_b32(), acc_sum);
+            break;
+
+        case 256:
+        case 512:
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
+                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
+                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
+
+                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
+                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
+
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
+
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2 += 32;
+
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+                }
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
+            break;
+
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif __ARM_NEON
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);
 
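
For reference (not part of the diff), the quantity both SVE paths compute is the same as ggml's portable scalar fallback for this function. A sketch closely following that reference, assuming the standard Q2_K/Q8_K block layouts with QK_K == 256 (each scale byte packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble; the q8 bsums hold sums of 16 consecutive int8 values):

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const  int8_t * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;

        // min contributions, using the precomputed per-group q8 sums
        int summs = 0;
        for (int j = 0; j < 16; ++j) {
            summs += y[i].bsums[j] * (sc[j] >> 4);
        }

        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

        // scale-weighted dot products over groups of 16 two-bit quants
        int isum = 0;
        int is = 0;
        for (int k = 0; k < QK_K/128; ++k) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                for (int seg = 0; seg < 2; ++seg) {
                    const int d = sc[is++] & 0xF;
                    int isuml = 0;
                    for (int l = 16*seg; l < 16*(seg+1); ++l) {
                        isuml += q8[l] * ((q2[l] >> shift) & 3);
                    }
                    isum += d * isuml;
                }
                shift += 2;
                q8 += 32;
            }
            q2 += 32;
        }
        sumf += dall * isum - dmin * summs;
    }
    *s = sumf;

In the SVE code above, sumi1 accumulates the dall * isum term, while the mins-times-bsums products folded into acc_sum (with dmin negated up front) supply the -dmin * summs term.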