ggerganov commited on
Commit
8a67e9f
·
1 Parent(s): 80d6ec0

ggml : adapt AMX to tensor->grad removal (llama/0)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-amx/ggml-amx.cpp +0 -3
ggml/src/ggml-amx/ggml-amx.cpp CHANGED
@@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
317
  const enum ggml_type type = src0->type;
318
  const int64_t ne0 = op->ne[0];
319
 
320
- bool is_training = src0->grad || src1->grad;
321
-
322
  // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
323
  // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
324
  bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
@@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
326
  bool can_use_amx =
327
  is_contiguous_2d(src0) && // src0 must be contiguous
328
  is_contiguous_2d(src1) && // src1 must be contiguous
329
- !is_training && // inference only
330
  src1->type == GGML_TYPE_F32 && // src1 must be float32
331
  has_amx_kernels && // with amx kernel impls
332
  ne0 % (TILE_N * 2) == 0; // out_features is 32x
 
317
  const enum ggml_type type = src0->type;
318
  const int64_t ne0 = op->ne[0];
319
 
 
 
320
  // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
321
  // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
322
  bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
 
324
  bool can_use_amx =
325
  is_contiguous_2d(src0) && // src0 must be contiguous
326
  is_contiguous_2d(src1) && // src1 must be contiguous
 
327
  src1->type == GGML_TYPE_F32 && // src1 must be float32
328
  has_amx_kernels && // with amx kernel impls
329
  ne0 % (TILE_N * 2) == 0; // out_features is 32x