Spaces:
Running
Running
ggml : adapt AMX to tensor->grad removal (llama/0)
Browse files
ggml/src/ggml-amx/ggml-amx.cpp
CHANGED
|
@@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
|
|
| 317 |
const enum ggml_type type = src0->type;
|
| 318 |
const int64_t ne0 = op->ne[0];
|
| 319 |
|
| 320 |
-
bool is_training = src0->grad || src1->grad;
|
| 321 |
-
|
| 322 |
// amx kernels enabled for Q4_0, Q4_1, Q8_0, F16
|
| 323 |
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
|
| 324 |
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
|
|
@@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
|
|
| 326 |
bool can_use_amx =
|
| 327 |
is_contiguous_2d(src0) && // src0 must be contiguous
|
| 328 |
is_contiguous_2d(src1) && // src1 must be contiguous
|
| 329 |
-
!is_training && // inference only
|
| 330 |
src1->type == GGML_TYPE_F32 && // src1 must be float32
|
| 331 |
has_amx_kernels && // with amx kernel impls
|
| 332 |
ne0 % (TILE_N * 2) == 0; // out_features is 32x
|
|
|
|
| 317 |
const enum ggml_type type = src0->type;
|
| 318 |
const int64_t ne0 = op->ne[0];
|
| 319 |
|
|
|
|
|
|
|
| 320 |
// amx kernels enabled for Q4_0, Q4_1, Q8_0, F16
|
| 321 |
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
|
| 322 |
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
|
|
|
|
| 324 |
bool can_use_amx =
|
| 325 |
is_contiguous_2d(src0) && // src0 must be contiguous
|
| 326 |
is_contiguous_2d(src1) && // src1 must be contiguous
|
|
|
|
| 327 |
src1->type == GGML_TYPE_F32 && // src1 must be float32
|
| 328 |
has_amx_kernels && // with amx kernel impls
|
| 329 |
ne0 % (TILE_N * 2) == 0; // out_features is 32x
|