Spaces:
Sleeping
Sleeping
Justine Tunney
committed on
Commit
·
bfe2a5f
1
Parent(s):
12af87c
llamafile : improve sgemm.cpp (llama/6796)
Browse files
* llamafile : improve sgemm.cpp
- Re-enable by default
- Fix issue described in #6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` and `n` edge cases
* Address review comments
* Help clang produce fma instructions
* Address review comments
ggml.c
CHANGED
|
@@ -10887,7 +10887,7 @@ static void ggml_compute_forward_mul_mat(
|
|
| 10887 |
#endif
|
| 10888 |
|
| 10889 |
#if GGML_USE_LLAMAFILE
|
| 10890 |
-
if (
|
| 10891 |
for (int64_t i13 = 0; i13 < ne13; i13++)
|
| 10892 |
for (int64_t i12 = 0; i12 < ne12; i12++)
|
| 10893 |
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
|
@@ -10940,15 +10940,13 @@ UseGgmlGemm1:;
|
|
| 10940 |
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
| 10941 |
|
| 10942 |
#if GGML_USE_LLAMAFILE
|
| 10943 |
-
if (
|
| 10944 |
for (int64_t i13 = 0; i13 < ne13; i13++)
|
| 10945 |
for (int64_t i12 = 0; i12 < ne12; i12++)
|
| 10946 |
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
| 10947 |
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
| 10948 |
nb01/ggml_type_size(src0->type),
|
| 10949 |
-
(const char *)wdata +
|
| 10950 |
-
nb12/ggml_type_size(src1->type)*i12 +
|
| 10951 |
-
nb13/ggml_type_size(src1->type)*i13),
|
| 10952 |
row_size/ggml_type_size(vec_dot_type),
|
| 10953 |
(char *)dst->data + i12*nb2 + i13*nb3,
|
| 10954 |
nb1/ggml_type_size(dst->type),
|
|
|
|
| 10887 |
#endif
|
| 10888 |
|
| 10889 |
#if GGML_USE_LLAMAFILE
|
| 10890 |
+
if (src1_cont) {
|
| 10891 |
for (int64_t i13 = 0; i13 < ne13; i13++)
|
| 10892 |
for (int64_t i12 = 0; i12 < ne12; i12++)
|
| 10893 |
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
|
|
|
| 10940 |
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
| 10941 |
|
| 10942 |
#if GGML_USE_LLAMAFILE
|
| 10943 |
+
if (src1->type != vec_dot_type) {
|
| 10944 |
for (int64_t i13 = 0; i13 < ne13; i13++)
|
| 10945 |
for (int64_t i12 = 0; i12 < ne12; i12++)
|
| 10946 |
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
| 10947 |
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
| 10948 |
nb01/ggml_type_size(src0->type),
|
| 10949 |
+
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
|
|
|
|
|
|
| 10950 |
row_size/ggml_type_size(vec_dot_type),
|
| 10951 |
(char *)dst->data + i12*nb2 + i13*nb3,
|
| 10952 |
nb1/ggml_type_size(dst->type),
|