Justine Tunney committed on
Commit
bfe2a5f
·
1 Parent(s): 12af87c

llamafile : improve sgemm.cpp (llama/6796)

Browse files

* llamafile : improve sgemm.cpp

- Re-enable by default
- Fix issue described in #6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` and `n` edge cases

* Address review comments

* Help clang produce fma instructions

* Address review comments

Files changed (1) hide show
  1. ggml.c +3 -5
ggml.c CHANGED
@@ -10887,7 +10887,7 @@ static void ggml_compute_forward_mul_mat(
10887
  #endif
10888
 
10889
  #if GGML_USE_LLAMAFILE
10890
- if (nb10 == ggml_type_size(src1->type)) {
10891
  for (int64_t i13 = 0; i13 < ne13; i13++)
10892
  for (int64_t i12 = 0; i12 < ne12; i12++)
10893
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -10940,15 +10940,13 @@ UseGgmlGemm1:;
10940
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10941
 
10942
  #if GGML_USE_LLAMAFILE
10943
- if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
10944
  for (int64_t i13 = 0; i13 < ne13; i13++)
10945
  for (int64_t i12 = 0; i12 < ne12; i12++)
10946
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
10947
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
10948
  nb01/ggml_type_size(src0->type),
10949
- (const char *)wdata + ggml_row_size(vec_dot_type,
10950
- nb12/ggml_type_size(src1->type)*i12 +
10951
- nb13/ggml_type_size(src1->type)*i13),
10952
  row_size/ggml_type_size(vec_dot_type),
10953
  (char *)dst->data + i12*nb2 + i13*nb3,
10954
  nb1/ggml_type_size(dst->type),
 
10887
  #endif
10888
 
10889
  #if GGML_USE_LLAMAFILE
10890
+ if (src1_cont) {
10891
  for (int64_t i13 = 0; i13 < ne13; i13++)
10892
  for (int64_t i12 = 0; i12 < ne12; i12++)
10893
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
 
10940
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10941
 
10942
  #if GGML_USE_LLAMAFILE
10943
+ if (src1->type != vec_dot_type) {
10944
  for (int64_t i13 = 0; i13 < ne13; i13++)
10945
  for (int64_t i12 = 0; i12 < ne12; i12++)
10946
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
10947
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
10948
  nb01/ggml_type_size(src0->type),
10949
+ (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
 
 
10950
  row_size/ggml_type_size(vec_dot_type),
10951
  (char *)dst->data + i12*nb2 + i13*nb3,
10952
  nb1/ggml_type_size(dst->type),