Max Krasnyansky ggerganov committed on
Commit
c917076
·
1 Parent(s): 6f8daf7

Add support for properly optimized Windows ARM64 builds with LLVM and MSVC (llama/7191)

Browse files

* logging: add proper checks for clang to avoid errors and warnings with VA_ARGS

* build: add CMake Presets and toolchain files for Windows ARM64

* matmul-int8: enable matmul-int8 with MSVC and fix Clang warnings

* ci: add support for optimized Windows ARM64 builds with MSVC and LLVM

* matmul-int8: fixed typos in q8_0_q8_0 matmuls

Co-authored-by: Georgi Gerganov <[email protected]>

* matmul-int8: remove unnecessary casts in q8_0_q8_0

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (1) hide show
  1. ggml-quants.c +28 -25
ggml-quants.c CHANGED
@@ -3487,10 +3487,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3487
  #if defined(__ARM_FEATURE_MATMUL_INT8)
3488
  if (nrc == 2) {
3489
  const block_q4_0 * restrict vx0 = vx;
3490
- const block_q4_0 * restrict vx1 = vx + bx;
3491
-
3492
  const block_q8_0 * restrict vy0 = vy;
3493
- const block_q8_0 * restrict vy1 = vy + by;
3494
 
3495
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3496
 
@@ -3524,10 +3523,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3524
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3525
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3526
 
3527
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3528
- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3529
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3530
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
 
 
3531
 
3532
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3533
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3894,9 +3895,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3894
  #if defined(__ARM_FEATURE_MATMUL_INT8)
3895
  if (nrc == 2) {
3896
  const block_q4_1 * restrict vx0 = vx;
3897
- const block_q4_1 * restrict vx1 = vx + bx;
3898
  const block_q8_1 * restrict vy0 = vy;
3899
- const block_q8_1 * restrict vy1 = vy + by;
3900
 
3901
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3902
  float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3907,11 +3908,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3907
  const block_q8_1 * restrict b_y0 = &vy0[i];
3908
  const block_q8_1 * restrict b_y1 = &vy1[i];
3909
 
3910
- float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3911
- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3912
- GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3913
- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3914
- summs0 += summs_t;
3915
 
3916
  const uint8x16_t m4b = vdupq_n_u8(0x0F);
3917
 
@@ -3931,10 +3932,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3931
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3932
 
3933
  // mmla into int32x4_t
3934
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3935
- GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3936
- GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3937
- GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
 
3938
 
3939
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3940
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3953,7 +3955,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3953
 
3954
  float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3955
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3956
- sumv2 = sumv2 + summs0;
3957
 
3958
  vst1_f32(s, vget_low_f32(sumv2));
3959
  vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4837,9 +4839,9 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4837
  #if defined(__ARM_FEATURE_MATMUL_INT8)
4838
  if (nrc == 2) {
4839
  const block_q8_0 * restrict vx0 = vx;
4840
- const block_q8_0 * restrict vx1 = vx + bx;
4841
  const block_q8_0 * restrict vy0 = vy;
4842
- const block_q8_0 * restrict vy1 = vy + by;
4843
 
4844
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
4845
 
@@ -4861,10 +4863,11 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4861
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4862
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4863
 
4864
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4865
- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4866
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4867
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
 
4868
 
4869
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4870
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
 
3487
  #if defined(__ARM_FEATURE_MATMUL_INT8)
3488
  if (nrc == 2) {
3489
  const block_q4_0 * restrict vx0 = vx;
3490
+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
 
3491
  const block_q8_0 * restrict vy0 = vy;
3492
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
3493
 
3494
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3495
 
 
3523
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3524
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3525
 
3526
+ float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3527
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3528
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3529
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3530
+
3531
+ float32x4_t scale = vld1q_f32(_scale);
3532
 
3533
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3534
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
 
3895
  #if defined(__ARM_FEATURE_MATMUL_INT8)
3896
  if (nrc == 2) {
3897
  const block_q4_1 * restrict vx0 = vx;
3898
+ const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
3899
  const block_q8_1 * restrict vy0 = vy;
3900
+ const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
3901
 
3902
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3903
  float32x4_t summs0 = vdupq_n_f32(0.0f);
 
3908
  const block_q8_1 * restrict b_y0 = &vy0[i];
3909
  const block_q8_1 * restrict b_y1 = &vy1[i];
3910
 
3911
+ float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3912
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3913
+ GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3914
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3915
+ summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
3916
 
3917
  const uint8x16_t m4b = vdupq_n_u8(0x0F);
3918
 
 
3932
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3933
 
3934
  // mmla into int32x4_t
3935
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3936
+ GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3937
+ GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3938
+ GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3939
+ float32x4_t scale = vld1q_f32(_scale);
3940
 
3941
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3942
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
 
3955
 
3956
  float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3957
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3958
+ sumv2 = vaddq_f32(sumv2, summs0);
3959
 
3960
  vst1_f32(s, vget_low_f32(sumv2));
3961
  vst1_f32(s + bs, vget_high_f32(sumv2));
 
4839
  #if defined(__ARM_FEATURE_MATMUL_INT8)
4840
  if (nrc == 2) {
4841
  const block_q8_0 * restrict vx0 = vx;
4842
+ const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
4843
  const block_q8_0 * restrict vy0 = vy;
4844
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
4845
 
4846
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
4847
 
 
4863
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4864
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4865
 
4866
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4867
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4868
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4869
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4870
+ float32x4_t scale = vld1q_f32(_scale);
4871
 
4872
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4873
  int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));