amritahs-ibm committed on
Commit d154905 · 1 Parent(s): 3f95f2b

llamafile : ppc64le MMA implementation for Q4_0. (llama/12489)


This change upstreams llamafile's CPU matrix
multiplication kernels for the ppc64le ISA using MMA
builtins. This patch handles matrix multiplication
between the quantised datatypes block_q4_0 and
block_q8_0.

This change results in a 5% - 50% improvement
in total speed (i.e. all tokens/total time) across
various batch sizes.

The patch is tested with the Meta-Llama-3-8B,
Mistral-7B and Llama-2-7B-chat-hf models on an
IBM POWER10 machine.

Signed-off-by: Amrita H S <[email protected]>
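For reference, the per-block computation these kernels perform is the standard Q4_0 x Q8_0 dot product: a block_q4_0 packs 32 weights as 4-bit nibbles plus an fp16 scale, and a block_q8_0 holds 32 int8 quants plus an fp16 scale. Below is a minimal scalar sketch of one block pair's contribution, assuming ggml's usual block layouts (illustration only, not code from this patch; to_float stands in for ggml's fp16-to-fp32 helper, unhalf() in sgemm.cpp). The new packNormalInt4() in the diff performs the same nibble unpacking (mask with 0xF, shift right by 4, subtract 8) with VSX vectors before feeding the MMA accumulators.

#include <cstdint>

// Layouts mirror ggml's q4_0 / q8_0 blocks: 32 quants per block plus an fp16 scale.
struct block_q4_0 { uint16_t d; uint8_t qs[16]; };  // 32 x 4-bit weights, two per byte
struct block_q8_0 { uint16_t d; int8_t  qs[32]; };  // 32 x 8-bit quants

// to_float: fp16 -> fp32 conversion, e.g. ggml's GGML_FP16_TO_FP32.
float dot_q4_0_q8_0_block(const block_q4_0 &x, const block_q8_0 &y,
                          float (*to_float)(uint16_t)) {
    int sum = 0;
    for (int j = 0; j < 16; ++j) {
        const int lo = (x.qs[j] & 0x0F) - 8;  // low nibble  -> element j
        const int hi = (x.qs[j] >> 4)   - 8;  // high nibble -> element j + 16
        sum += lo * y.qs[j] + hi * y.qs[j + 16];
    }
    return to_float(x.d) * to_float(y.d) * (float)sum;
}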

Files changed (1)
  1. ggml/src/ggml-cpu/llamafile/sgemm.cpp  +517 -86
ggml/src/ggml-cpu/llamafile/sgemm.cpp CHANGED
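One pattern worth noting before the diff: the kernels pack the q8_0 operand with flip=true, an XOR that (assuming the usual 0x80 constant) adds 128 to every signed byte so the data can be fed to the integer MMA accumulate as an unsigned operand, and they undo that offset afterwards with the per-row sums in comparray, visible below as CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0)). A small self-contained check of that identity, as an illustrative sketch rather than code from the patch:

#include <cassert>
#include <cstdint>

int main() {
    int8_t a[4] = { 3, -7, 120, -128 };   // signed operand (rows of A)
    int8_t b[4] = { -5, 11,  -2,   64 };  // signed q8_0 data before the flip

    int true_dot = 0, flipped_dot = 0, sum_a = 0;
    for (int j = 0; j < 4; ++j) {
        true_dot    += a[j] * b[j];
        flipped_dot += a[j] * (uint8_t)(b[j] ^ 0x80);  // what the unsigned-operand MAC accumulates
        sum_a       += a[j];
    }
    // Adding sum(a) * -128 recovers the signed dot product, which is exactly
    // the comparray correction applied after __builtin_mma_disassemble_acc().
    assert(flipped_dot + sum_a * -128 == true_dot);
    return 0;
}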
@@ -55,6 +55,7 @@
55
 
56
  #include <atomic>
57
  #include <array>
 
58
 
59
  #ifdef _MSC_VER
60
  #define NOINLINE __declspec(noinline)
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
1092
  }
1093
  }
1094
 
1095
- template<typename VA, typename VB>
1096
- void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1097
  int64_t i, j;
1098
  TA *aoffset = NULL;
1099
  VA *vecOffset = NULL;
1100
  TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1101
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1102
  __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1103
  VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
1104
  VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
1111
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1112
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1113
 
1114
- aoffset = const_cast<TA*>(a);
1115
  vecOffset = vec;
1116
  j = (rows >> 3);
1117
  if (j > 0) {
1118
  do {
1119
- aoffset1 = aoffset;
1120
- aoffset2 = aoffset1 + lda;
1121
- aoffset3 = aoffset2 + lda;
1122
- aoffset4 = aoffset3 + lda;
1123
- aoffset5 = aoffset4 + lda;
1124
- aoffset6 = aoffset5 + lda;
1125
- aoffset7 = aoffset6 + lda;
1126
- aoffset8 = aoffset7 + lda;
1127
- aoffset += 8 * lda;
1128
 
1129
- i = (cols >> 3);
1130
- if (i > 0) {
1131
- do {
1132
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1133
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1134
  C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
1156
  t7 = vec_perm(t2, t4, swiz3);
1157
  t8 = vec_perm(t2, t4, swiz4);
1158
  if (flip == true) {
1159
- t5 = vec_xor(t5, xor_vector);
1160
- t6 = vec_xor(t6, xor_vector);
1161
- t7 = vec_xor(t7, xor_vector);
1162
- t8 = vec_xor(t8, xor_vector);
1163
  }
1164
  vec_xst(t5, 0, vecOffset);
1165
  vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
1175
  t7 = vec_perm(t2, t4, swiz3);
1176
  t8 = vec_perm(t2, t4, swiz4);
1177
  if (flip == true) {
1178
- t5 = vec_xor(t5, xor_vector);
1179
- t6 = vec_xor(t6, xor_vector);
1180
- t7 = vec_xor(t7, xor_vector);
1181
- t8 = vec_xor(t8, xor_vector);
1182
  }
1183
  vec_xst(t5, 0, vecOffset+64);
1184
  vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
1194
  t7 = vec_perm(t2, t4, swiz3);
1195
  t8 = vec_perm(t2, t4, swiz4);
1196
  if (flip == true) {
1197
- t5 = vec_xor(t5, xor_vector);
1198
- t6 = vec_xor(t6, xor_vector);
1199
- t7 = vec_xor(t7, xor_vector);
1200
- t8 = vec_xor(t8, xor_vector);
1201
  }
1202
  vec_xst(t5, 0, vecOffset+128);
1203
  vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
1213
  t7 = vec_perm(t2, t4, swiz3);
1214
  t8 = vec_perm(t2, t4, swiz4);
1215
  if (flip == true) {
1216
- t5 = vec_xor(t5, xor_vector);
1217
- t6 = vec_xor(t6, xor_vector);
1218
- t7 = vec_xor(t7, xor_vector);
1219
- t8 = vec_xor(t8, xor_vector);
1220
  }
1221
  vec_xst(t5, 0, vecOffset+192);
1222
  vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
1240
  }
1241
 
1242
  if (rows & 4) {
1243
- aoffset1 = aoffset;
1244
- aoffset2 = aoffset1 + lda;
1245
- aoffset3 = aoffset2 + lda;
1246
- aoffset4 = aoffset3 + lda;
1247
- aoffset += 4 * lda;
1248
 
1249
  i = (cols >> 3);
1250
  if (i > 0) {
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
1311
  aoffset2 = aoffset1 + lda;
1312
  aoffset3 = aoffset2 + lda;
1313
  i = (cols >> 3);
1314
- if (i > 0) {
1315
  do {
1316
  switch(rows) {
1317
  case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
1527
  void KERNEL_4x8(int64_t ii, int64_t jj) {
1528
  vec_t vec_A[8], vec_B[16] = {0};
1529
  acc_t acc_0, acc_1;
1530
- std::array<int, 4> comparray;
1531
  vector float fin_res[8] = {0};
1532
  vector float vs[8] = {0};
 
1533
  for (int l = 0; l < k; l++) {
1534
  __builtin_mma_xxsetaccz(&acc_0);
1535
  __builtin_mma_xxsetaccz(&acc_1);
1536
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1537
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1538
  for(int x = 0; x < 8; x++) {
1539
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
1545
  *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1546
  }
1547
  }
1548
- auto aoffset = A+(ii*lda)+l;
1549
- for (int i = 0; i < 4; i++) {
1550
- comparray[i] = 0;
1551
- int ca = 0;
1552
- const int8_t *at = aoffset->qs;
1553
- for (int j = 0; j < 32; j++)
1554
- ca += (int)*at++;
1555
- comparray[i] = ca;
1556
- aoffset += lda;
1557
  }
1558
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1559
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
1565
  void KERNEL_8x4(int64_t ii, int64_t jj) {
1566
  vec_t vec_A[16], vec_B[8] = {0};
1567
  acc_t acc_0, acc_1;
1568
- std::array<int, 8> comparray;
1569
  vector float fin_res[8] = {0};
1570
  vector float vs[8] = {0};
 
1571
  for (int l = 0; l < k; l++) {
1572
  __builtin_mma_xxsetaccz(&acc_0);
1573
  __builtin_mma_xxsetaccz(&acc_1);
1574
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1575
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1576
  for(int x = 0; x < 8; x++) {
1577
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
1582
  *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1583
  }
1584
  }
1585
- auto aoffset = A+(ii*lda)+l;
1586
- for (int i = 0; i < 8; i++) {
1587
- comparray[i] = 0;
1588
- int ca = 0;
1589
- const int8_t *at = aoffset->qs;
1590
- for (int j = 0; j < 32; j++)
1591
- ca += (int)*at++;
1592
- comparray[i] = ca;
1593
- aoffset += lda;
1594
  }
1595
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1596
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
1602
  void KERNEL_8x8(int64_t ii, int64_t jj) {
1603
  vec_t vec_A[16], vec_B[16] = {0};
1604
  acc_t acc_0, acc_1, acc_2, acc_3;
1605
- std::array<int, 8> comparray;
1606
  vector float fin_res[16] = {0};
1607
  vector float vs[16] = {0};
 
1608
  for (int l = 0; l < k; l++) {
1609
  __builtin_mma_xxsetaccz(&acc_0);
1610
  __builtin_mma_xxsetaccz(&acc_1);
1611
  __builtin_mma_xxsetaccz(&acc_2);
1612
  __builtin_mma_xxsetaccz(&acc_3);
1613
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1614
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1615
  for(int x = 0; x < 8; x++) {
1616
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
1624
  *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1625
  }
1626
  }
1627
- auto aoffset = A+(ii*lda)+l;
1628
- for (int i = 0; i < 8; i++) {
1629
- comparray[i] = 0;
1630
- int ca = 0;
1631
- const int8_t *at = aoffset->qs;
1632
- for (int j = 0; j < 32; j++)
1633
- ca += (int)*at++;
1634
- comparray[i] = ca;
1635
- aoffset += lda;
 
 
1636
  }
1637
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1638
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
1653
  int64_t duty = (tiles + nth - 1) / nth;
1654
  int64_t start = duty * ith;
1655
  int64_t end = start + duty;
1656
- vec_t vec_A[8], vec_B[8] = {0};
1657
  vector signed int vec_C[4];
1658
  acc_t acc_0;
 
1659
 
1660
  if (end > tiles)
1661
  end = tiles;
1662
  for (int64_t job = start; job < end; ++job) {
1663
  int64_t ii = m0 + job / xtiles * RM;
1664
  int64_t jj = n0 + job % xtiles * RN;
1665
- std::array<int, RM> comparray;
1666
  vector float res[4] = {0};
1667
  vector float fin_res[4] = {0};
1668
  vector float vs[4] = {0};
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
1673
  __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1674
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1675
  __builtin_mma_xxsetaccz(&acc_0);
1676
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
1677
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
1678
  for(int x = 0; x < 8; x+=4) {
1679
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
1687
  }
1688
  }
1689
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
1690
- auto aoffset = A+(ii*lda)+l;
1691
- for (int i = 0; i < RM; i++) {
1692
- comparray[i] = 0;
1693
- int ca = 0;
1694
- const int8_t *at = aoffset->qs;
1695
- for (int j = 0; j < 32; j++)
1696
- ca += (int)*at++;
1697
- comparray[i] = ca;
1698
- aoffset += lda;
 
 
1699
  }
1700
-
1701
  for (int i = 0; i < RM; i++) {
1702
  CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
1703
  res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
2013
  }
2014
  }
2015
  }
 
2016
  void KERNEL_4x4(int64_t ii, int64_t jj) {
2017
  vec_t vec_A[4], vec_B[4], vec_C[4];
2018
  acc_t acc_0;
@@ -2259,7 +2678,7 @@ class tinyBLAS_PPC {
2259
  vec_t vec_C[4];
2260
  acc_t acc_0;
2261
  __builtin_mma_xxsetaccz(&acc_0);
2262
- vec_t vec_A[4], vec_B[4];
2263
  for (int l=0; l<k; l+=4) {
2264
  if (RN >= 4 && RM == 1) {
2265
  TA* a = const_cast<TA*>(A+(ii)*lda+l);
@@ -2503,8 +2922,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2503
  params->ith, params->nth};
2504
  tb.matmul(m, n);
2505
  return true;
2506
-
2507
  #elif defined(__MMA__)
 
2508
  if (n < 8 && n != 4)
2509
  return false;
2510
  if (m < 8 && m != 4)
@@ -2516,7 +2935,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2516
  params->ith, params->nth};
2517
  tb.matmul(m, n);
2518
  return true;
2519
-
2520
  #else
2521
  return false;
2522
  #endif
@@ -2541,6 +2959,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2541
  params->ith, params->nth};
2542
  tb.matmul(m, n);
2543
return true;
2544
  #else
2545
  return false;
2546
  #endif
 
55
 
56
  #include <atomic>
57
  #include <array>
58
+ #include <type_traits>
59
 
60
  #ifdef _MSC_VER
61
  #define NOINLINE __declspec(noinline)
 
1093
  }
1094
  }
1095
 
1096
+ template<typename VA, typename VB, int size>
1097
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
1098
  int64_t i, j;
1099
  TA *aoffset = NULL;
1100
  VA *vecOffset = NULL;
1101
  TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1102
  TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1103
+ VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1104
+ VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1105
+ VB t1, t2, t3, t4, t5, t6, t7, t8;
1106
+ const vector signed char lowMask = vec_splats((signed char)0xF);
1107
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1108
+ const vector signed char v8 = vec_splats((signed char)0x8);
1109
+ aoffset = const_cast<TA*>(a);
1110
+ vecOffset = vec;
1111
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1112
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1113
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1114
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1115
+ vector signed int vsum = {0};
1116
+ vector signed int vsum2 = {0};
1117
+
1118
+ j = (rows >> 3);
1119
+ if (j > 0) {
1120
+ do {
1121
+ aoffset1 = aoffset;
1122
+ aoffset2 = aoffset1 + lda;
1123
+ aoffset3 = aoffset2 + lda;
1124
+ aoffset4 = aoffset3 + lda;
1125
+ aoffset5 = aoffset4 + lda;
1126
+ aoffset6 = aoffset5 + lda;
1127
+ aoffset7 = aoffset6 + lda;
1128
+ aoffset8 = aoffset7 + lda;
1129
+ aoffset += 8 * lda;
1130
+
1131
+ i = (cols >> 2);
1132
+ if (i > 0) {
1133
+ do {
1134
+ c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1135
+ c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1136
+ c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1137
+ c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1138
+ c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
1139
+ c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
1140
+ c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
1141
+ c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
1142
+
1143
+ c1[0] = vec_and(c1[1], lowMask);
1144
+ c1[1] = vec_sr(c1[1], v4);
1145
+ c1[0] = vec_sub(c1[0], v8);
1146
+ c1[1] = vec_sub(c1[1], v8);
1147
+ vsum = vec_sum4s(c1[0], vsum);
1148
+ vsum2 = vec_sum4s(c1[1], vsum2);
1149
+ vsum = vec_add(vsum, vsum2);
1150
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1151
+ vsum = vec_splats(0);
1152
+ vsum2 = vec_splats(0);
1153
+
1154
+ c2[0] = vec_and(c2[1], lowMask);
1155
+ c2[1] = vec_sr(c2[1], v4);
1156
+ c2[0] = vec_sub(c2[0], v8);
1157
+ c2[1] = vec_sub(c2[1], v8);
1158
+ vsum = vec_sum4s(c2[0], vsum);
1159
+ vsum2 = vec_sum4s(c2[1], vsum2);
1160
+ vsum = vec_add(vsum, vsum2);
1161
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1162
+ vsum = vec_splats(0);
1163
+ vsum2 = vec_splats(0);
1164
+
1165
+ c3[0] = vec_and(c3[1], lowMask);
1166
+ c3[1] = vec_sr(c3[1], v4);
1167
+ c3[0] = vec_sub(c3[0], v8);
1168
+ c3[1] = vec_sub(c3[1], v8);
1169
+ vsum = vec_sum4s(c3[0], vsum);
1170
+ vsum2 = vec_sum4s(c3[1], vsum2);
1171
+ vsum = vec_add(vsum, vsum2);
1172
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1173
+ vsum = vec_splats(0);
1174
+ vsum2 = vec_splats(0);
1175
+
1176
+ c4[0] = vec_and(c4[1], lowMask);
1177
+ c4[1] = vec_sr(c4[1], v4);
1178
+ c4[0] = vec_sub(c4[0], v8);
1179
+ c4[1] = vec_sub(c4[1], v8);
1180
+ vsum = vec_sum4s(c4[0], vsum);
1181
+ vsum2 = vec_sum4s(c4[1], vsum2);
1182
+ vsum = vec_add(vsum, vsum2);
1183
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1184
+ vsum = vec_splats(0);
1185
+ vsum2 = vec_splats(0);
1186
+
1187
+ c5[0] = vec_and(c5[1], lowMask);
1188
+ c5[1] = vec_sr(c5[1], v4);
1189
+ c5[0] = vec_sub(c5[0], v8);
1190
+ c5[1] = vec_sub(c5[1], v8);
1191
+ vsum = vec_sum4s(c5[0], vsum);
1192
+ vsum2 = vec_sum4s(c5[1], vsum2);
1193
+ vsum = vec_add(vsum, vsum2);
1194
+ comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1195
+ vsum = vec_splats(0);
1196
+ vsum2 = vec_splats(0);
1197
+
1198
+ c6[0] = vec_and(c6[1], lowMask);
1199
+ c6[1] = vec_sr(c6[1], v4);
1200
+ c6[0] = vec_sub(c6[0], v8);
1201
+ c6[1] = vec_sub(c6[1], v8);
1202
+ vsum = vec_sum4s(c6[0], vsum);
1203
+ vsum2 = vec_sum4s(c6[1], vsum2);
1204
+ vsum = vec_add(vsum, vsum2);
1205
+ comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1206
+ vsum = vec_splats(0);
1207
+ vsum2 = vec_splats(0);
1208
+
1209
+ c7[0] = vec_and(c7[1], lowMask);
1210
+ c7[1] = vec_sr(c7[1], v4);
1211
+ c7[0] = vec_sub(c7[0], v8);
1212
+ c7[1] = vec_sub(c7[1], v8);
1213
+ vsum = vec_sum4s(c7[0], vsum);
1214
+ vsum2 = vec_sum4s(c7[1], vsum2);
1215
+ vsum = vec_add(vsum, vsum2);
1216
+ comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1217
+ vsum = vec_splats(0);
1218
+ vsum2 = vec_splats(0);
1219
+
1220
+ c8[0] = vec_and(c8[1], lowMask);
1221
+ c8[1] = vec_sr(c8[1], v4);
1222
+ c8[0] = vec_sub(c8[0], v8);
1223
+ c8[1] = vec_sub(c8[1], v8);
1224
+ vsum = vec_sum4s(c8[0], vsum);
1225
+ vsum2 = vec_sum4s(c8[1], vsum2);
1226
+ vsum = vec_add(vsum, vsum2);
1227
+ comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1228
+ vsum = vec_splats(0);
1229
+ vsum2 = vec_splats(0);
1230
+
1231
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1232
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1233
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1234
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1235
+ t5 = vec_perm(t1, t3, swiz3);
1236
+ t6 = vec_perm(t1, t3, swiz4);
1237
+ t7 = vec_perm(t2, t4, swiz3);
1238
+ t8 = vec_perm(t2, t4, swiz4);
1239
+ vec_xst(t5, 0, vecOffset);
1240
+ vec_xst(t6, 0, vecOffset+16);
1241
+ vec_xst(t7, 0, vecOffset+32);
1242
+ vec_xst(t8, 0, vecOffset+48);
1243
+
1244
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1245
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1246
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1247
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1248
+ t5 = vec_perm(t1, t3, swiz3);
1249
+ t6 = vec_perm(t1, t3, swiz4);
1250
+ t7 = vec_perm(t2, t4, swiz3);
1251
+ t8 = vec_perm(t2, t4, swiz4);
1252
+ vec_xst(t5, 0, vecOffset+64);
1253
+ vec_xst(t6, 0, vecOffset+80);
1254
+ vec_xst(t7, 0, vecOffset+96);
1255
+ vec_xst(t8, 0, vecOffset+112);
1256
+
1257
+ t1 = vec_perm(c5[0], c6[0], swiz1);
1258
+ t2 = vec_perm(c5[0], c6[0], swiz2);
1259
+ t3 = vec_perm(c7[0], c8[0], swiz1);
1260
+ t4 = vec_perm(c7[0], c8[0], swiz2);
1261
+ t5 = vec_perm(t1, t3, swiz3);
1262
+ t6 = vec_perm(t1, t3, swiz4);
1263
+ t7 = vec_perm(t2, t4, swiz3);
1264
+ t8 = vec_perm(t2, t4, swiz4);
1265
+ vec_xst(t5, 0, vecOffset+128);
1266
+ vec_xst(t6, 0, vecOffset+144);
1267
+ vec_xst(t7, 0, vecOffset+160);
1268
+ vec_xst(t8, 0, vecOffset+176);
1269
+
1270
+ t1 = vec_perm(c5[1], c6[1], swiz1);
1271
+ t2 = vec_perm(c5[1], c6[1], swiz2);
1272
+ t3 = vec_perm(c7[1], c8[1], swiz1);
1273
+ t4 = vec_perm(c7[1], c8[1], swiz2);
1274
+ t5 = vec_perm(t1, t3, swiz3);
1275
+ t6 = vec_perm(t1, t3, swiz4);
1276
+ t7 = vec_perm(t2, t4, swiz3);
1277
+ t8 = vec_perm(t2, t4, swiz4);
1278
+ vec_xst(t5, 0, vecOffset+192);
1279
+ vec_xst(t6, 0, vecOffset+208);
1280
+ vec_xst(t7, 0, vecOffset+224);
1281
+ vec_xst(t8, 0, vecOffset+240);
1282
+
1283
+ aoffset1 += lda;
1284
+ aoffset2 += lda;
1285
+ aoffset3 += lda;
1286
+ aoffset4 += lda;
1287
+ aoffset5 += lda;
1288
+ aoffset6 += lda;
1289
+ aoffset7 += lda;
1290
+ aoffset8 += lda;
1291
+ vecOffset += 256;
1292
+ i--;
1293
+ } while (i > 0);
1294
+ }
1295
+ j--;
1296
+ } while (j > 0);
1297
+ }
1298
+
1299
+ if (rows & 4) {
1300
+ aoffset1 = aoffset;
1301
+ aoffset2 = aoffset1 + lda;
1302
+ aoffset3 = aoffset2 + lda;
1303
+ aoffset4 = aoffset3 + lda;
1304
+ aoffset += 4 * lda;
1305
+
1306
+ i = (cols >> 2);
1307
+ if (i > 0) {
1308
+ do {
1309
+ c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1310
+ c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1311
+ c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1312
+ c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1313
+
1314
+ c1[0] = vec_and(c1[1], lowMask);
1315
+ c1[1] = vec_sr(c1[1], v4);
1316
+ c1[0] = vec_sub(c1[0], v8);
1317
+ c1[1] = vec_sub(c1[1], v8);
1318
+ vsum = vec_sum4s(c1[0], vsum);
1319
+ vsum2 = vec_sum4s(c1[1], vsum2);
1320
+ vsum = vec_add(vsum, vsum2);
1321
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1322
+ vsum = vec_splats(0);
1323
+ vsum2 = vec_splats(0);
1324
+
1325
+ c2[0] = vec_and(c2[1], lowMask);
1326
+ c2[1] = vec_sr(c2[1], v4);
1327
+ c2[0] = vec_sub(c2[0], v8);
1328
+ c2[1] = vec_sub(c2[1], v8);
1329
+ vsum = vec_sum4s(c2[0], vsum);
1330
+ vsum2 = vec_sum4s(c2[1], vsum2);
1331
+ vsum = vec_add(vsum, vsum2);
1332
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1333
+ vsum = vec_splats(0);
1334
+ vsum2 = vec_splats(0);
1335
+
1336
+ c3[0] = vec_and(c3[1], lowMask);
1337
+ c3[1] = vec_sr(c3[1], v4);
1338
+ c3[0] = vec_sub(c3[0], v8);
1339
+ c3[1] = vec_sub(c3[1], v8);
1340
+ vsum = vec_sum4s(c3[0], vsum);
1341
+ vsum2 = vec_sum4s(c3[1], vsum2);
1342
+ vsum = vec_add(vsum, vsum2);
1343
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1344
+ vsum = vec_splats(0);
1345
+ vsum2 = vec_splats(0);
1346
+
1347
+ c4[0] = vec_and(c4[1], lowMask);
1348
+ c4[1] = vec_sr(c4[1], v4);
1349
+ c4[0] = vec_sub(c4[0], v8);
1350
+ c4[1] = vec_sub(c4[1], v8);
1351
+ vsum = vec_sum4s(c4[0], vsum);
1352
+ vsum2 = vec_sum4s(c4[1], vsum2);
1353
+ vsum = vec_add(vsum, vsum2);
1354
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1355
+ vsum = vec_splats(0);
1356
+ vsum2 = vec_splats( 0);
1357
+
1358
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1359
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1360
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1361
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1362
+ t5 = vec_perm(t1, t3, swiz3);
1363
+ t6 = vec_perm(t1, t3, swiz4);
1364
+ t7 = vec_perm(t2, t4, swiz3);
1365
+ t8 = vec_perm(t2, t4, swiz4);
1366
+ vec_xst(t5, 0, vecOffset);
1367
+ vec_xst(t6, 0, vecOffset+16);
1368
+ vec_xst(t7, 0, vecOffset+32);
1369
+ vec_xst(t8, 0, vecOffset+48);
1370
+
1371
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1372
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1373
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1374
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1375
+ t5 = vec_perm(t1, t3, swiz3);
1376
+ t6 = vec_perm(t1, t3, swiz4);
1377
+ t7 = vec_perm(t2, t4, swiz3);
1378
+ t8 = vec_perm(t2, t4, swiz4);
1379
+ vec_xst(t5, 0, vecOffset+64);
1380
+ vec_xst(t6, 0, vecOffset+80);
1381
+ vec_xst(t7, 0, vecOffset+96);
1382
+ vec_xst(t8, 0, vecOffset+112);
1383
+
1384
+ aoffset1 += lda;
1385
+ aoffset2 += lda;
1386
+ aoffset3 += lda;
1387
+ aoffset4 += lda;
1388
+ vecOffset += 128;
1389
+ i--;
1390
+ } while (i > 0);
1391
+ }
1392
+ }
1393
+
1394
+ if (rows & 3) {
1395
+ aoffset1 = aoffset;
1396
+ aoffset2 = aoffset1 + lda;
1397
+ aoffset3 = aoffset2 + lda;
1398
+ i = (cols >> 2);
1399
+ if (i > 0) {
1400
+ do {
1401
+ switch(rows) {
1402
+ case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1403
+ case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1404
+ case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1405
+ break;
1406
+ }
1407
+ c1[0] = vec_and(c1[1], lowMask);
1408
+ c1[1] = vec_sr(c1[1], v4);
1409
+ c1[0] = vec_sub(c1[0], v8);
1410
+ c1[1] = vec_sub(c1[1], v8);
1411
+ vsum = vec_sum4s(c1[0], vsum);
1412
+ vsum2 = vec_sum4s(c1[1], vsum2);
1413
+ vsum = vec_add(vsum, vsum2);
1414
+ comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1415
+ vsum = vec_splats(0);
1416
+ vsum2 = vec_splats(0);
1417
+
1418
+ c2[0] = vec_and(c2[1], lowMask);
1419
+ c2[1] = vec_sr(c2[1], v4);
1420
+ c2[0] = vec_sub(c2[0], v8);
1421
+ c2[1] = vec_sub(c2[1], v8);
1422
+ vsum = vec_sum4s(c2[0], vsum);
1423
+ vsum2 = vec_sum4s(c2[1], vsum2);
1424
+ vsum = vec_add(vsum, vsum2);
1425
+ comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1426
+ vsum = vec_splats(0);
1427
+ vsum2 = vec_splats(0);
1428
+
1429
+ c3[0] = vec_and(c3[1], lowMask);
1430
+ c3[1] = vec_sr(c3[1], v4);
1431
+ c3[0] = vec_sub(c3[0], v8);
1432
+ c3[1] = vec_sub(c3[1], v8);
1433
+ vsum = vec_sum4s(c3[0], vsum);
1434
+ vsum2 = vec_sum4s(c3[1], vsum2);
1435
+ vsum = vec_add(vsum, vsum2);
1436
+ comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1437
+ vsum = vec_splats(0);
1438
+ vsum2 = vec_splats(0);
1439
+
1440
+ c4[0] = vec_and(c4[1], lowMask);
1441
+ c4[1] = vec_sr(c4[1], v4);
1442
+ c4[0] = vec_sub(c4[0], v8);
1443
+ c4[1] = vec_sub(c4[1], v8);
1444
+ vsum = vec_sum4s(c4[0], vsum);
1445
+ vsum2 = vec_sum4s(c4[1], vsum2);
1446
+ vsum = vec_add(vsum, vsum2);
1447
+ comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1448
+ vsum = vec_splats(0);
1449
+ vsum2 = vec_splats(0);
1450
+
1451
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1452
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1453
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1454
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1455
+ t5 = vec_perm(t1, t3, swiz3);
1456
+ t6 = vec_perm(t1, t3, swiz4);
1457
+ t7 = vec_perm(t2, t4, swiz3);
1458
+ t8 = vec_perm(t2, t4, swiz4);
1459
+ vec_xst(t5, 0, vecOffset);
1460
+ vec_xst(t6, 0, vecOffset+16);
1461
+ vec_xst(t7, 0, vecOffset+32);
1462
+ vec_xst(t8, 0, vecOffset+48);
1463
+
1464
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1465
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1466
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1467
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1468
+ t5 = vec_perm(t1, t3, swiz3);
1469
+ t6 = vec_perm(t1, t3, swiz4);
1470
+ t7 = vec_perm(t2, t4, swiz3);
1471
+ t8 = vec_perm(t2, t4, swiz4);
1472
+ vec_xst(t5, 0, vecOffset+64);
1473
+ vec_xst(t6, 0, vecOffset+80);
1474
+ vec_xst(t7, 0, vecOffset+96);
1475
+ vec_xst(t8, 0, vecOffset+112);
1476
+ aoffset1 += lda;
1477
+ aoffset2 += lda;
1478
+ aoffset3 += lda;
1479
+ vecOffset += 128;
1480
+ i--;
1481
+ } while(i > 0);
1482
+ }
1483
+ }
1484
+ }
1485
+
1486
+ template<typename VA, typename VB>
1487
+ void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1488
+ int64_t i, j;
1489
+ TB *aoffset = NULL;
1490
+ VA *vecOffset = NULL;
1491
+ TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1492
+ TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1493
  __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1494
  VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
1495
  VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
 
1502
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1503
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1504
 
1505
+ aoffset = const_cast<TB*>(a);
1506
  vecOffset = vec;
1507
  j = (rows >> 3);
1508
  if (j > 0) {
1509
  do {
1510
+ aoffset1 = aoffset;
1511
+ aoffset2 = aoffset1 + lda;
1512
+ aoffset3 = aoffset2 + lda;
1513
+ aoffset4 = aoffset3 + lda;
1514
+ aoffset5 = aoffset4 + lda;
1515
+ aoffset6 = aoffset5 + lda;
1516
+ aoffset7 = aoffset6 + lda;
1517
+ aoffset8 = aoffset7 + lda;
1518
+ aoffset += 8 * lda;
1519
 
1520
+ i = (cols >> 3);
1521
+ if (i > 0) {
1522
+ do {
1523
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1524
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1525
  C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
 
1547
  t7 = vec_perm(t2, t4, swiz3);
1548
  t8 = vec_perm(t2, t4, swiz4);
1549
  if (flip == true) {
1550
+ t5 = vec_xor(t5, xor_vector);
1551
+ t6 = vec_xor(t6, xor_vector);
1552
+ t7 = vec_xor(t7, xor_vector);
1553
+ t8 = vec_xor(t8, xor_vector);
1554
  }
1555
  vec_xst(t5, 0, vecOffset);
1556
  vec_xst(t6, 0, vecOffset+16);
 
1566
  t7 = vec_perm(t2, t4, swiz3);
1567
  t8 = vec_perm(t2, t4, swiz4);
1568
  if (flip == true) {
1569
+ t5 = vec_xor(t5, xor_vector);
1570
+ t6 = vec_xor(t6, xor_vector);
1571
+ t7 = vec_xor(t7, xor_vector);
1572
+ t8 = vec_xor(t8, xor_vector);
1573
  }
1574
  vec_xst(t5, 0, vecOffset+64);
1575
  vec_xst(t6, 0, vecOffset+80);
 
1585
  t7 = vec_perm(t2, t4, swiz3);
1586
  t8 = vec_perm(t2, t4, swiz4);
1587
  if (flip == true) {
1588
+ t5 = vec_xor(t5, xor_vector);
1589
+ t6 = vec_xor(t6, xor_vector);
1590
+ t7 = vec_xor(t7, xor_vector);
1591
+ t8 = vec_xor(t8, xor_vector);
1592
  }
1593
  vec_xst(t5, 0, vecOffset+128);
1594
  vec_xst(t6, 0, vecOffset+144);
 
1604
  t7 = vec_perm(t2, t4, swiz3);
1605
  t8 = vec_perm(t2, t4, swiz4);
1606
  if (flip == true) {
1607
+ t5 = vec_xor(t5, xor_vector);
1608
+ t6 = vec_xor(t6, xor_vector);
1609
+ t7 = vec_xor(t7, xor_vector);
1610
+ t8 = vec_xor(t8, xor_vector);
1611
  }
1612
  vec_xst(t5, 0, vecOffset+192);
1613
  vec_xst(t6, 0, vecOffset+208);
 
1631
  }
1632
 
1633
  if (rows & 4) {
1634
+ aoffset1 = aoffset;
1635
+ aoffset2 = aoffset1 + lda;
1636
+ aoffset3 = aoffset2 + lda;
1637
+ aoffset4 = aoffset3 + lda;
1638
+ aoffset += 4 * lda;
1639
 
1640
  i = (cols >> 3);
1641
  if (i > 0) {
 
1702
  aoffset2 = aoffset1 + lda;
1703
  aoffset3 = aoffset2 + lda;
1704
  i = (cols >> 3);
1705
+ if (i > 0) {
1706
  do {
1707
  switch(rows) {
1708
  case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
 
1918
  void KERNEL_4x8(int64_t ii, int64_t jj) {
1919
  vec_t vec_A[8], vec_B[16] = {0};
1920
  acc_t acc_0, acc_1;
1921
+ std::array<int, 4> comparray {};
1922
  vector float fin_res[8] = {0};
1923
  vector float vs[8] = {0};
1924
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1925
  for (int l = 0; l < k; l++) {
1926
  __builtin_mma_xxsetaccz(&acc_0);
1927
  __builtin_mma_xxsetaccz(&acc_1);
1928
+ if (std::is_same_v<TA, block_q4_0>) {
1929
+ packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
1930
+ } else {
1931
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1932
+ }
1933
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1934
  for(int x = 0; x < 8; x++) {
1935
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
 
1941
  *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1942
  }
1943
  }
1944
+ if (!isAblock_q4) {
1945
+ auto aoffset = A+(ii*lda)+l;
1946
+ for (int i = 0; i < 4; i++) {
1947
+ comparray[i] = 0;
1948
+ int ca = 0;
1949
+ auto *at = aoffset->qs;
1950
+ for (int j = 0; j < 32; j++)
1951
+ ca += (int)*at++;
1952
+ comparray[i] = ca;
1953
+ aoffset += lda;
1954
+ }
1955
  }
1956
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1957
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
 
1963
  void KERNEL_8x4(int64_t ii, int64_t jj) {
1964
  vec_t vec_A[16], vec_B[8] = {0};
1965
  acc_t acc_0, acc_1;
1966
+ std::array<int, 8> comparray {};
1967
  vector float fin_res[8] = {0};
1968
  vector float vs[8] = {0};
1969
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1970
  for (int l = 0; l < k; l++) {
1971
  __builtin_mma_xxsetaccz(&acc_0);
1972
  __builtin_mma_xxsetaccz(&acc_1);
1973
+ if (std::is_same_v<TA, block_q4_0>) {
1974
+ packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
1975
+ } else {
1976
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1977
+ }
1978
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1979
  for(int x = 0; x < 8; x++) {
1980
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
 
1985
  *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1986
  }
1987
  }
1988
+ if (!isAblock_q4) {
1989
+ auto aoffset = A+(ii*lda)+l;
1990
+ for (int i = 0; i < 8; i++) {
1991
+ comparray[i] = 0;
1992
+ int ca = 0;
1993
+ auto *at = aoffset->qs;
1994
+ for (int j = 0; j < 32; j++)
1995
+ ca += (int)*at++;
1996
+ comparray[i] = ca;
1997
+ aoffset += lda;
1998
+ }
1999
  }
2000
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2001
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
 
2007
  void KERNEL_8x8(int64_t ii, int64_t jj) {
2008
  vec_t vec_A[16], vec_B[16] = {0};
2009
  acc_t acc_0, acc_1, acc_2, acc_3;
2010
+ std::array<int, 8> comparray {};
2011
  vector float fin_res[16] = {0};
2012
  vector float vs[16] = {0};
2013
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2014
  for (int l = 0; l < k; l++) {
2015
  __builtin_mma_xxsetaccz(&acc_0);
2016
  __builtin_mma_xxsetaccz(&acc_1);
2017
  __builtin_mma_xxsetaccz(&acc_2);
2018
  __builtin_mma_xxsetaccz(&acc_3);
2019
+ if (std::is_same_v<TA, block_q4_0>) {
2020
+ packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2021
+ } else {
2022
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2023
+ }
2024
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2025
  for(int x = 0; x < 8; x++) {
2026
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
 
2034
  *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2035
  }
2036
  }
2037
+ if (!isAblock_q4) {
2038
+ auto aoffset = A+(ii*lda)+l;
2039
+ for (int i = 0; i < 8; i++) {
2040
+ comparray[i] = 0;
2041
+ int ca = 0;
2042
+ auto *at = aoffset->qs;
2043
+ for (int j = 0; j < 32; j++)
2044
+ ca += (int)*at++;
2045
+ comparray[i] = ca;
2046
+ aoffset += lda;
2047
+ }
2048
  }
2049
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2050
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
 
2065
  int64_t duty = (tiles + nth - 1) / nth;
2066
  int64_t start = duty * ith;
2067
  int64_t end = start + duty;
2068
+ vec_t vec_A[8] = {0}, vec_B[8] = {0};
2069
  vector signed int vec_C[4];
2070
  acc_t acc_0;
2071
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2072
 
2073
  if (end > tiles)
2074
  end = tiles;
2075
  for (int64_t job = start; job < end; ++job) {
2076
  int64_t ii = m0 + job / xtiles * RM;
2077
  int64_t jj = n0 + job % xtiles * RN;
2078
+ std::array<int, 4> comparray{};
2079
  vector float res[4] = {0};
2080
  vector float fin_res[4] = {0};
2081
  vector float vs[4] = {0};
 
2086
  __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2087
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2088
  __builtin_mma_xxsetaccz(&acc_0);
2089
+ if (isAblock_q4) {
2090
+ packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2091
+ } else {
2092
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2093
+ }
2094
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2095
  for(int x = 0; x < 8; x+=4) {
2096
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
 
2104
  }
2105
  }
2106
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
2107
+ if (!isAblock_q4) {
2108
+ auto aoffset = A+(ii*lda)+l;
2109
+ for (int i = 0; i < RM; i++) {
2110
+ comparray[i] = 0;
2111
+ int ca = 0;
2112
+ auto *at = aoffset->qs;
2113
+ for (int j = 0; j < 32; j++)
2114
+ ca += (int)*at++;
2115
+ comparray[i] = ca;
2116
+ aoffset += lda;
2117
+ }
2118
  }
 
2119
  for (int i = 0; i < RM; i++) {
2120
  CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
2121
  res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
 
2431
  }
2432
  }
2433
  }
2434
+
2435
  void KERNEL_4x4(int64_t ii, int64_t jj) {
2436
  vec_t vec_A[4], vec_B[4], vec_C[4];
2437
  acc_t acc_0;
 
2678
  vec_t vec_C[4];
2679
  acc_t acc_0;
2680
  __builtin_mma_xxsetaccz(&acc_0);
2681
+ vec_t vec_A[4] {0}, vec_B[4] = {0};
2682
  for (int l=0; l<k; l+=4) {
2683
  if (RN >= 4 && RM == 1) {
2684
  TA* a = const_cast<TA*>(A+(ii)*lda+l);
 
2922
  params->ith, params->nth};
2923
  tb.matmul(m, n);
2924
  return true;
 
2925
  #elif defined(__MMA__)
2926
+ //TO-DO: Remove this condition once gemv forwarding is enabled.
2927
  if (n < 8 && n != 4)
2928
  return false;
2929
  if (m < 8 && m != 4)
 
2935
  params->ith, params->nth};
2936
  tb.matmul(m, n);
2937
  return true;
 
2938
  #else
2939
  return false;
2940
  #endif
 
2959
  params->ith, params->nth};
2960
  tb.matmul(m, n);
2961
  return true;
2962
+ #elif defined(__MMA__)
2963
+ //TO-DO: Remove this condition once gemv forwarding is enabled.
2964
+ if (n < 8 && n != 4)
2965
+ return false;
2966
+ if (m < 8 && m != 4)
2967
+ return false;
2968
+ tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
2969
+ k, (const block_q4_0 *)A, lda,
2970
+ (const block_q8_0 *)B, ldb,
2971
+ (float *)C, ldc,
2972
+ params->ith, params->nth};
2973
+ tb.matmul(m, n);
2974
+ return true;
2975
  #else
2976
  return false;
2977
  #endif