shalinib committed
Commit 7da5bcc · 1 Parent(s): 31cad24

ggml : Enable MMA for BF16 in llamafile_sgemm (llama/13148)


This patch upstreams llamafile's CPU matrix-multiplication kernels for ppc64le, using MMA builtins for the BF16 data type.
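
A minimal sketch of the POWER10 MMA flow these kernels build on, for orientation (assumes GCC with -mcpu=power10; the helper name and single-tile shape are illustrative, not part of the patch):

    #include <altivec.h>

    typedef vector unsigned char vec_t;

    // One MMA accumulator tile of FP32 partial sums:
    // acc += outer product of the BF16 pairs packed in vec_a and vec_b.
    void mma_bf16_tile(vec_t vec_a, vec_t vec_b, vector float out[4]) {
        __vector_quad acc;
        __builtin_mma_xxsetaccz(&acc);                  // zero the 512-bit accumulator
        __builtin_mma_xvbf16ger2pp(&acc, vec_a, vec_b); // BF16 rank-2 update with accumulate
        __builtin_mma_disassemble_acc(out, &acc);       // spill into four FP32 vectors
    }

The patched kernels below do this at scale: packNormal() interleaves the BF16 operands into the layout the accumulator expects, and the KERNEL_* routines drive several accumulators at once.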

This change results in 9x-40x gains in total speed S t/s (i.e., all tokens / total time) across the various batch sizes tested with the llama-batched-bench benchmark.

The patch was tested with the Meta-Llama-3-8B and Mistral-7B models (BF16 models generated by running llama-quantize on the corresponding FP32 models) on an IBM POWER10 machine.
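
For reference, BF16 is the high half of an IEEE-754 FP32 value (1 sign bit, 8 exponent bits, 7 mantissa bits). A sketch of the round-to-nearest-even conversion follows, with a hypothetical fp32_to_bf16 helper rather than ggml's own routine (which also special-cases NaN):

    #include <cstdint>
    #include <cstring>

    static uint16_t fp32_to_bf16(float f) {   // hypothetical helper, illustration only
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits)); // bit-cast without UB
        bits += 0x7FFF + ((bits >> 16) & 1);  // round to nearest, ties to even
        return (uint16_t)(bits >> 16);        // keep sign, exponent, top 7 mantissa bits
    }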

Signed-off-by: Shalini Salomi Bodapati <[email protected]>

Files changed (1)
  1. ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
ggml/src/ggml-cpu/llamafile/sgemm.cpp CHANGED
@@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
     } \
     } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+                      const TA *A, int64_t lda,
+                      const TB *B, int64_t ldb,
+                      TC *C, int64_t ldc,
+                      int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        if (numVec == 2) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[2], c[3], swiz1);
+            s[0] = vec_perm(t[0], t[1], swiz3);
+            s[1] = vec_perm(t[0], t[1], swiz4);
+            vec_xst(s[0], 0, (vec_t*)vecOffset);
+            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+        } else if (numVec == 4) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[0], c[1], swiz2);
+            t[2] = vec_perm(c[2], c[3], swiz1);
+            t[3] = vec_perm(c[2], c[3], swiz2);
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            for (int i = 0; i < 4; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        } else if (numVec == 8) {
+            for (int i = 0; i < 4; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            for (int i = 4; i < 8; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            s[4] = vec_perm(t[4], t[6], swiz3);
+            s[5] = vec_perm(t[4], t[6], swiz4);
+            s[6] = vec_perm(t[5], t[7], swiz3);
+            s[7] = vec_perm(t[5], t[7], swiz4);
+            for (int i = 0; i < 8; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        }
+    }
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA * aoffsets[8];
+        vector unsigned char c_arr[8];
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 4; ++it)
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffset += 4 * lda;
+                    for (int i = 0; i < 4; ++i)
+                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int i = 0; i < 4; i++)
+                        aoffsets[i] = aoffsets[i] + lda;
+                    vecOffset += 64;
+                }
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 8; ++it) {
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    }
+                    aoffset += 8 * lda;
+                    do {
+                        for (int it = 0; it < 8; ++it)
+                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                        vector_permute_store(c_arr, 8, vecOffset);
+                        for (int it = 0; it < 8; ++it)
+                            aoffsets[it] = aoffsets[it] + 8*lda;
+                        vecOffset += 128;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+        if (rows & 4) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
+            if (cols == 4) {
+                for (int it = 0; it < 4; ++it)
+                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it < 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    for (int it = 0; it < 4; ++it)
+                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it < 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            if (cols == 4) {
+                switch(rows) {
+                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                            break;
+                }
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it < 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                                break;
+                    }
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it < 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+            }
+        }
+
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l = 0; l < k; l+=4) {
+                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x < 2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l = 0; l < k; l+=8) {
+                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                for (int x = 0; x < 4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
 
@@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
+
             do {
                 aoffset1 = aoffset;
                 aoffset2 = aoffset1 + lda;
 
@@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                 (float *)C, ldc};
             return tb.matmul(m, n);
         }
+#elif defined(__MMA__)
+        if ((k % 8))
+            return false;
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth};
+            tb.matmul(m, n);
+            return true;
+        }
 #endif
         return false;
     }
+
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {