Bizhao Shi committed
Commit 112c144 · 1 Parent(s): 8c2a700

CANN: Add basic support for the Flash Attention kernel (llama/13627)


* cann: add the basic FA support

* cann: update the readme

* cann: update the FlashAttention with PSEShift

* cann: update the input parameters in FA

* cann: update the alibi with max_bias

* cann: add the constraints of softcap

* cann: update the docs CANN.md

* cann: update the docs CANN.md

* cann: fix typo of CANN.md

* cann: add some comments and update the CANN.md

* cann: update the CANN.md

* cann: update the inner precise for fusedInferAttention

* cann: update the constraints of flash_attn_ext on ggml-cann.cpp

* cann: clean the whitespace

* cann: clean the whitespace

* cann: add a new endline
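
Note (not part of the commit): the kernel added below reads three floats out of dst->op_params: scale at float offset 0, max_bias at offset 1, and logit_softcap at offset 2. The following is a minimal standalone sketch of that layout; the array size and example values are assumptions for illustration, only the offsets and the memcpy pattern come from the diff.

// Sketch of the FLASH_ATTN_EXT op_params layout consumed by ggml_cann_flash_attn_ext().
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[16] = {0};         // stand-in for ggml_tensor::op_params (size is illustrative)
    const float scale         = 0.125f;  // e.g. 1/sqrt(head_dim) with head_dim = 64
    const float max_bias      = 0.0f;    // 0 disables the ALiBi/PSEShift path
    const float logit_softcap = 0.0f;    // non-zero is rejected by this kernel

    // Pack the parameters the same way ggml does for GGML_OP_FLASH_ATTN_EXT.
    memcpy((float *) op_params + 0, &scale,         sizeof(float));
    memcpy((float *) op_params + 1, &max_bias,      sizeof(float));
    memcpy((float *) op_params + 2, &logit_softcap, sizeof(float));

    // The CANN kernel reads them back with the same offsets.
    float scaleValue, maxBias, logitSoftcap;
    memcpy(&scaleValue,   (float *) op_params + 0, sizeof(float));
    memcpy(&maxBias,      (float *) op_params + 1, sizeof(float));
    memcpy(&logitSoftcap, (float *) op_params + 2, sizeof(float));
    printf("scale=%g max_bias=%g logit_softcap=%g\n", scaleValue, maxBias, logitSoftcap);
    return 0;
}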

ggml/src/ggml-cann/CMakeLists.txt CHANGED
File without changes
ggml/src/ggml-cann/Doxyfile CHANGED
File without changes
ggml/src/ggml-cann/acl_tensor.cpp CHANGED
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
ggml/src/ggml-cann/acl_tensor.h CHANGED
File without changes
ggml/src/ggml-cann/aclnn_ops.cpp CHANGED
@@ -66,6 +66,7 @@
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
 #include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
@@ -74,11 +75,13 @@
 #include <vector>
 
 #include "ggml-impl.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
 #include "../ggml-common.h"
 
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             break;
     }
 }
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    float maxBias = 0.0f;
+    float scaleValue = 1.0f;
+    float logitSoftcap = 0.0f;
+    memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
+
+    if(logitSoftcap == 0.0f){
+        size_t faElemSize = sizeof(uint16_t);
+        auto faDataType = ACL_FLOAT16; // ACL_BF16;
+
+        aclTensor* acl_src0_f16_tensor = nullptr;
+        aclTensor* acl_src1_f16_tensor = nullptr;
+        aclTensor* acl_src2_f16_tensor = nullptr;
+        aclTensor* acl_dst_f16_tensor  = nullptr;
+
+        // Step 1: cast the src0 (Query) to fp16 if needed
+        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+        void* src0_f16_buffer = nullptr;
+
+        if(ggml_cann_type_mapping(src0->type) != faDataType){
+            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+            src0_f16_buffer = src0_f16_allocator.alloc(
+                ggml_nelements(src0) * faElemSize);
+
+            int64_t* src0_f16_ne = src0->ne;
+            size_t src0_f16_nb[GGML_MAX_DIMS];
+            src0_f16_nb[0] = sizeof(uint16_t);
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+            }
+
+            acl_src0_f16_tensor = ggml_cann_create_tensor(
+                src0_f16_buffer, faDataType, faElemSize,
+                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+            );
+            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+        }else{
+            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        }
+
+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        //         and the direct output from FusedInferAttention
+
+        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        void* out_f16_buffer = out_f16_allocator.alloc(
+            ggml_nelements(dst) * faElemSize);
+
+        int64_t* out_f16_ne = src0->ne;
+        size_t out_f16_nb[GGML_MAX_DIMS];
+        out_f16_nb[0] = faElemSize;
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+        }
+
+        acl_dst_f16_tensor = ggml_cann_create_tensor(
+            out_f16_buffer, faDataType, faElemSize,
+            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+        );
+
+        // Step 3: create the PSEShift tensor if needed
+        //         this tensor is treated as the mask (f16) in llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+        void* bcast_pse_buffer = nullptr;
+
+        if(src3 != nullptr){
+            bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+            if(src0->ne[1] > 1){
+                // Case 1: broadcast pse for the prefill stage with multiple heads
+                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src3->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+            }else{
+                // Case 2: truncate the first row and broadcast pse for the decode stage with multiple heads
+                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+                size_t* trunc_pse_nb = src3->nb;
+
+                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+                    src3->data, ACL_FLOAT16, sizeof(uint16_t),
+                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src0->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+            }
+
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
+            if(maxBias != 0.0f){
+                // alibi
+                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+                const int64_t n_head = src0->ne[2];
+                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
+                // init arange
+                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                      ne2_ne3 * faElemSize);
+                void* tmp_arange_buffer = arange_allocator.get();
+
+                // arange1: [1, ..., n_heads_log2_floor+1)
+                float start = 1;
+                float stop = n_heads_log2_floor + 1;
+                float step = 1;
+                int64_t n_elements_arange = n_heads_log2_floor;
+
+                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+                size_t tmp_arange1_nb[] = {faElemSize};
+                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_arange1_ne, tmp_arange1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+                aclTensor* tmp_arange2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+                    start = 1;
+                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                    step = 2;
+                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_arange2_nb[] = {faElemSize};
+
+                    tmp_arange2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_arange_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                                 n_elements_arange);
+                }
+
+                // init mk_base
+                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                       ne2_ne3 * faElemSize);
+                void* tmp_mk_base_buffer = mk_base_allocator.get();
+                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+                size_t tmp_mk_base1_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base1_ne, tmp_mk_base1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+                aclTensor* tmp_mk_base2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_mk_base2_nb[] = {faElemSize};
+                    tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_mk_base_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+                }
+
+                // init mk
+                int64_t tmp_mk_base_ne[] = {ne2_ne3};
+                size_t tmp_mk_base_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+                // reshape mk
+                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+                size_t tmp_mk_nb[GGML_MAX_DIMS];
+                tmp_mk_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+                }
+                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                    ACL_FORMAT_ND);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                    tmp_arange_tensor, tmp_mk_tensor);
+            }
+        }
+
+        // Step 4: set the inputs for FusedInferAttention.
+        int kvTensorNum = 1;
+        aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+        int64_t numHeads = src0->ne[2];         // N
+        int64_t numKeyValueHeads = src1->ne[2];
+        // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+        int64_t preTokens = 65535;
+        int64_t nextTokens = 65535;
+        char layout[5] = {'B', 'N', 'S', 'D', 0};
+        int64_t sparseMode = 0;
+        int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+        int64_t blockSize = 0;
+        int64_t antiquantMode = 0;
+        bool softmaxLseFlag = false;
+        int64_t keyAntiquantMode = 0;
+        int64_t valueAntiquantMode = 0;
+
+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+            bcast_pse_tensor, nullptr,                          // pse, mask
+            nullptr, nullptr,                                   // actSeqLen, actSeqLenkv
+            nullptr, nullptr,                                   // deqScale1, quantScale1
+            nullptr, nullptr, nullptr,                          // deqScale2, quantScale2, quantOffset2
+            nullptr, nullptr,                                   // antiquantScale, antiquantOffset
+            nullptr,                                            // blockTable
+            nullptr, nullptr,                                   // qPadSize, kvPadSize
+            nullptr, nullptr,                                   // kAntiquantScale, kAntiQuantOffset
+            nullptr, nullptr,                                   // vAntiquantScale, vAntiQuantOffset
+            nullptr, nullptr, nullptr,                          // kSharedPrefix, vSharedPrefix, actSharedLen
+            numHeads, scaleValue,                               // heads, scaleValue
+            preTokens, nextTokens,                              // preTokens, nextTokens
+            layout,                                             // inputLayout
+            numKeyValueHeads,                                   // numKVHeads
+            sparseMode, innerPrecise,                           // sparseMode, innerPrecise
+            blockSize, antiquantMode,                           // blockSize, antiquantMode
+            softmaxLseFlag,                                     // softmaxLseFlag
+            keyAntiquantMode, valueAntiquantMode,               // keyAntiqMode, valueAntiqMode
+            acl_dst_f16_tensor,                                 // attentionOut
+            nullptr                                             // softmaxLse
+        );
+
+        // Step 6: post-processing, permute and cast to f32
+
+        int64_t new_dim[] = {0, 2, 1, 3};
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        if(ggml_cann_type_mapping(dst->type) != faDataType){
+            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+            void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+            int64_t* perm_out_f16_ne = dst->ne;
+            size_t perm_out_f16_nb[GGML_MAX_DIMS];
+            perm_out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+            }
+            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+                perm_out_f16_buffer, faDataType, faElemSize,
+                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+            aclnn_cast(ctx,
+                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+        }else{
+            // only need to permute
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+        }
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                         acl_src1_f16_tensor,
+                                         acl_src2_f16_tensor,
+                                         acl_dst_f16_tensor,
+                                         acl_dst_tensor);
+        if(src3 != nullptr){
+            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
+    }else{
+        GGML_ABORT("Function is not implemented.");
+    }
+}
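
Note on Step 3 above (illustration, not from the commit): when max_bias != 0, the arange / fill_scalar / pow sequence builds the usual ggml per-head ALiBi slopes on the device and then multiplies them into the broadcast PSEShift tensor. A host-side scalar sketch of the same slope formula follows; the head count and max_bias values are assumptions for the example.

// Scalar reference for the ALiBi slopes produced by the device-side pow(mk_base, arange):
//   slope(h) = m0^(h+1)                           for h <  n_heads_log2_floor
//   slope(h) = m1^(2*(h - n_heads_log2_floor)+1)  for h >= n_heads_log2_floor
#include <cmath>
#include <cstdio>

int main() {
    const int   n_head   = 12;    // example head count (assumption)
    const float max_bias = 8.0f;  // example max_bias (assumption)

    const int   n_heads_log2_floor = 1u << (unsigned) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    for (int h = 0; h < n_head; ++h) {
        const float slope = h < n_heads_log2_floor
            ? powf(m0, (float) (h + 1))
            : powf(m1, (float) (2 * (h - n_heads_log2_floor) + 1));
        printf("head %2d: slope %.6f\n", h, slope);
    }
    return 0;
}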
ggml/src/ggml-cann/aclnn_ops.h CHANGED
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ * for computing scaled dot-product attention with hardware acceleration.
+ * The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
ggml/src/ggml-cann/common.h CHANGED
File without changes
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
             return false;
     }
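
Note (illustration, not from the commit): the checks in the last hunk describe which FLASH_ATTN_EXT nodes the CANN backend now accepts: F16 K and V, equal K/V head sizes, head size not 192 or 576, a single sequence in dim 3, and logit_softcap == 0. The sketch below builds such a node with the ggml_flash_attn_ext() helper from ggml.h; the shapes and values are examples only, and the (head_dim, seq_len, n_head, batch) ne layout is the usual ggml convention rather than something stated in this diff.

// Graph-construction-only sketch of a FLASH_ATTN_EXT node that passes the new supports_op checks.
#include "ggml.h"
#include <math.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // only tensor metadata is needed for this sketch
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t head_dim = 128;  // not 192 and not 576, per the checks above
    const int64_t n_head   = 8;
    const int64_t n_q      = 32;
    const int64_t n_kv     = 256;

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_q,  n_head, 1); // ne[3] == 1
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv, n_head, 1); // K must be F16
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv, n_head, 1); // same head size as K

    const float scale         = 1.0f / sqrtf((float) head_dim);
    const float max_bias      = 0.0f; // no ALiBi, so a NULL mask is acceptable here
    const float logit_softcap = 0.0f; // a non-zero value would be rejected by supports_op

    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL,
                                                   scale, max_bias, logit_softcap);
    (void) out; // out->op == GGML_OP_FLASH_ATTN_EXT, dispatched to ggml_cann_flash_attn_ext()

    ggml_free(ctx);
    return 0;
}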