Bizhao Shi committed
Commit 112c144 · 1 parent: 8c2a700
CANN: Add basic support for the Flash Attention kernel (llama/13627)

* cann: add the basic FA support
* cann: update the readme
* cann: update the FlashAttention with PSEShift
* cann: update the input parameters in FA
* cann: update the alibi with max_bias
* cann: add the constraints of softcap
* cann: update the docs CANN.md
* cann: update the docs CANN.md
* cann: fix typo of CANN.md
* cann: add some comments and update the CANN.md
* cann: update the CANN.md
* cann: update the inner precise for fusedInferAttention
* cann: update the constraints of flash_attn_ext on ggml-cann.cpp
* cann: clean the whitespace
* cann: clean the whitespace
* cann: add a new endline
- ggml/src/ggml-cann/CMakeLists.txt +0 -0
- ggml/src/ggml-cann/Doxyfile +0 -0
- ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- ggml/src/ggml-cann/acl_tensor.h +0 -0
- ggml/src/ggml-cann/aclnn_ops.cpp +330 -0
- ggml/src/ggml-cann/aclnn_ops.h +15 -0
- ggml/src/ggml-cann/common.h +0 -0
- ggml/src/ggml-cann/ggml-cann.cpp +36 -0
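Note: for orientation, this is how a graph node that reaches the new kernel is typically built on the ggml side. A minimal sketch with hypothetical shapes; `ggml_flash_attn_ext` packs scale, max_bias and logit_softcap into `dst->op_params`, which the kernel below reads back with memcpy:

#include <cmath>
#include "ggml.h"

int main() {
    // Hypothetical shapes: head size 128, 8 heads, 32 query rows, 256 KV rows.
    ggml_init_params ip = { /*mem_size*/ 64u * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(ip);

    ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  32, 8, 1);
    ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
    ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
    // f16 mask, with the query-row dimension padded as ggml expects (GGML_KQ_MASK_PAD)
    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 256, 64);

    // scale = 1/sqrt(head_size); max_bias = 0 disables ALiBi;
    // logit_softcap must stay 0.0f for this kernel (see supports_op below).
    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                            1.0f / std::sqrt(128.0f), 0.0f, 0.0f);
    (void) out;
    ggml_free(ctx);
    return 0;
}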
ggml/src/ggml-cann/CMakeLists.txt
CHANGED
File without changes
ggml/src/ggml-cann/Doxyfile
CHANGED
File without changes
ggml/src/ggml-cann/acl_tensor.cpp
CHANGED
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
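Note: with the added case, a BF16 ggml tensor now maps directly to ACL_BF16 instead of falling through to the switch's default. A one-line illustration (tensor wrapping in this backend resolves its ACL dtype through this function):

aclDataType dt = ggml_cann_type_mapping(GGML_TYPE_BF16); // ACL_BF16 after this hunk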
ggml/src/ggml-cann/acl_tensor.h
CHANGED
File without changes
ggml/src/ggml-cann/aclnn_ops.cpp
CHANGED
@@ -66,6 +66,7 @@
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
 #include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
@@ -74,11 +75,13 @@
 #include <vector>
 
 #include "ggml-impl.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
 #include "../ggml-common.h"
 
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
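Note: the hunk below repeatedly derives dense byte strides from element counts with the pattern nb[i] = nb[i-1] * ne[i-1], because ACL tensors, like ggml tensors, are described by an (ne, nb) pair. A self-contained sketch with a hypothetical f16 shape:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical f16 tensor of shape ne = {128, 32, 8, 1} (innermost dim first).
    int64_t ne[4] = {128, 32, 8, 1};
    size_t  nb[4];
    nb[0] = sizeof(uint16_t);              // f16 element size, as in the kernel
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1] * ne[i - 1];     // dense layout, no row padding
    }
    // nb = {2, 256, 8192, 65536}: byte step when moving along each dimension
    std::printf("%zu %zu %zu %zu\n", nb[0], nb[1], nb[2], nb[3]);
    return 0;
}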
@@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             break;
     }
 }
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    float maxBias      = 0.0f;
+    float scaleValue   = 1.0f;
+    float logitSoftcap = 0.0f;
+    memcpy(&scaleValue,   (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&maxBias,      (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
+
+    if(logitSoftcap == 0.0f){
+        size_t faElemSize = sizeof(uint16_t);
+        auto   faDataType = ACL_FLOAT16; //ACL_BF16;
+
+        aclTensor* acl_src0_f16_tensor = nullptr;
+        aclTensor* acl_src1_f16_tensor = nullptr;
+        aclTensor* acl_src2_f16_tensor = nullptr;
+        aclTensor* acl_dst_f16_tensor  = nullptr;
+
+        // Step 1: cast the src0 (Query) to fp16 if needed
+        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+        void* src0_f16_buffer = nullptr;
+
+        if(ggml_cann_type_mapping(src0->type) != faDataType){
+            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+            src0_f16_buffer = src0_f16_allocator.alloc(
+                ggml_nelements(src0) * faElemSize);
+
+            int64_t* src0_f16_ne = src0->ne;
+            size_t src0_f16_nb[GGML_MAX_DIMS];
+            src0_f16_nb[0] = sizeof(uint16_t);
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+            }
+
+            acl_src0_f16_tensor = ggml_cann_create_tensor(
+                src0_f16_buffer, faDataType, faElemSize,
+                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+            );
+            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+        }else{
+            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        }
+
+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        // and the direct output from FusedInferAttention
+
+        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        void* out_f16_buffer = out_f16_allocator.alloc(
+            ggml_nelements(dst) * faElemSize);
+
+        int64_t* out_f16_ne = src0->ne;
+        size_t out_f16_nb[GGML_MAX_DIMS];
+        out_f16_nb[0] = faElemSize;
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+        }
+
+        acl_dst_f16_tensor = ggml_cann_create_tensor(
+            out_f16_buffer, faDataType, faElemSize,
+            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+        );
+
+        // Step 3: create the PSEShift tensor if needed
+        // this tensor is considered as the mask (f16) in llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+        void* bcast_pse_buffer = nullptr;
+
+        if(src3 != nullptr){
+            bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+            if(src0->ne[1] > 1){
+                // Case 1: broadcast pse for the prefill stage with multiple heads
+                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src3->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+            }else{
+                // Case 2: truncate the first row and broadcast pse for the decode stage with multiple heads
+                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+                size_t* trunc_pse_nb = src3->nb;
+
+                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+                    src3->data, ACL_FLOAT16, sizeof(uint16_t),
+                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src0->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+            }
+
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
+            if(maxBias != 0.0f){
+                // alibi
+                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+                const int64_t n_head = src0->ne[2];
+                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
+                // init arange
+                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                      ne2_ne3 * faElemSize);
+                void* tmp_arange_buffer = arange_allocator.get();
+
+                // arange1: [1, ..., n_heads_log2_floor+1)
+                float start = 1;
+                float stop = n_heads_log2_floor + 1;
+                float step = 1;
+                int64_t n_elements_arange = n_heads_log2_floor;
+
+                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+                size_t tmp_arange1_nb[] = {faElemSize};
+                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_arange1_ne, tmp_arange1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+                aclTensor* tmp_arange2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+                    start = 1;
+                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                    step = 2;
+                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_arange2_nb[] = {faElemSize};
+
+                    aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_arange_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                                 n_elements_arange);
+                }
+
+                // init mk_base
+                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                       ne2_ne3 * faElemSize);
+                void* tmp_mk_base_buffer = mk_base_allocator.get();
+                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+                size_t tmp_mk_base1_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base1_ne, tmp_mk_base1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+                aclTensor* tmp_mk_base2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_mk_base2_nb[] = {faElemSize};
+                    aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_mk_base_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+                }
+
+                // init mk
+                int64_t tmp_mk_base_ne[] = {ne2_ne3};
+                size_t tmp_mk_base_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+                // reshape mk
+                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+                size_t tmp_mk_nb[GGML_MAX_DIMS];
+                tmp_mk_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+                }
+                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                    ACL_FORMAT_ND);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                    tmp_arange_tensor, tmp_mk_tensor);
+            }
+        }
+
+        // Step 4: set the inputs for FusedInferAttention.
+        int kvTensorNum = 1;
+        aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+        int64_t numHeads = src0->ne[2]; // N
+        int64_t numKeyValueHeads = src1->ne[2];
+        // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+        int64_t preTokens = 65535;
+        int64_t nextTokens = 65535;
+        char layout[5] = {'B', 'N', 'S', 'D', 0};
+        int64_t sparseMode = 0;
+        int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+        int64_t blockSize = 0;
+        int64_t antiquantMode = 0;
+        bool softmaxLseFlag = false;
+        int64_t keyAntiquantMode = 0;
+        int64_t valueAntiquantMode = 0;
+
+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+            bcast_pse_tensor, nullptr,            // pse, mask
+            nullptr, nullptr,                     // actSeqLen, actSeqLenkv
+            nullptr, nullptr,                     // deqScale1, quantScale1
+            nullptr, nullptr, nullptr,            // deqScale2, quantScale2, quantOffset2
+            nullptr, nullptr,                     // antiquantScale, antiquantOffset
+            nullptr,                              // blockTable
+            nullptr, nullptr,                     // qPadSize, kvPadSize
+            nullptr, nullptr,                     // kAntiquantScale, kAntiQuantOffset
+            nullptr, nullptr,                     // vAntiquantScale, vAntiQuantOffset
+            nullptr, nullptr, nullptr,            // kSharedPrefix, vSharedPrefix, actSharedLen
+            numHeads, scaleValue,                 // heads, scaleValue
+            preTokens, nextTokens,                // preTokens, nextTokens
+            layout,                               // inputLayout
+            numKeyValueHeads,                     // numKVHeads
+            sparseMode, innerPrecise,             // sparseMode, innerPrecise
+            blockSize, antiquantMode,             // blockSize, antiquantMode
+            softmaxLseFlag,                       // softmaxLseFlag
+            keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+            acl_dst_f16_tensor,                   // attentionOut
+            nullptr                               // softmaxLse
+        );
+
+        // Step 6: post-processing, permute and cast to f32
+
+        int64_t new_dim[] = {0, 2, 1, 3};
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        if(ggml_cann_type_mapping(dst->type) != faDataType){
+            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+            void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+            int64_t* perm_out_f16_ne = dst->ne;
+            size_t perm_out_f16_nb[GGML_MAX_DIMS];
+            perm_out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+            }
+            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+                perm_out_f16_buffer, faDataType, faElemSize,
+                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+            aclnn_cast(ctx,
+                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+        }else{
+            // only need to permute
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+        }
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                         acl_src1_f16_tensor,
+                                         acl_src2_f16_tensor,
+                                         acl_dst_f16_tensor,
+                                         acl_dst_tensor);
+        if(src3 != nullptr){
+            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
+    }else{
+        GGML_ABORT("Function is not implemented.");
+    }
+}
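Note on Step 3: when max_bias > 0, the arange/fill/pow passes above materialize ggml's standard ALiBi slope schedule per head. A host-side sketch of the same schedule (head count hypothetical; the m0/m1 formulas mirror the hunk):

#include <cmath>
#include <cstdio>

int main() {
    const int   n_head   = 12;    // hypothetical, deliberately not a power of two
    const float max_bias = 8.0f;  // would arrive via dst->op_params[1]
    const int   floor2   = 1 << (int)std::floor(std::log2((float)n_head)); // = 8
    const float m0 = std::pow(2.0f, -max_bias / floor2);          // base for heads 0..7
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / floor2); // base for heads 8..11
    for (int h = 0; h < n_head; ++h) {
        const float slope = h < floor2
            ? std::pow(m0, (float)(h + 1))                 // m0^(h+1)
            : std::pow(m1, (float)(2 * (h - floor2) + 1)); // m1^(2(h-floor)+1)
        std::printf("head %2d: slope %g\n", h, slope);
    }
    return 0;
}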
ggml/src/ggml-cann/aclnn_ops.h
CHANGED
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ * for computing scaled dot-product attention with hardware acceleration.
+ * The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
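Note: for reference, the computation the new header entry documents, with the mask folded in as the PSE term (notation added here, not taken from the header):

\mathrm{dst} = \operatorname{softmax}\left( \text{scaleValue} \cdot Q K^{\top} + \text{slope} \cdot \text{mask} \right) V

where slope is 1 unless maxBias > 0, in which case the per-head ALiBi slopes from Step 3 of aclnn_ops.cpp scale the mask.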
ggml/src/ggml-cann/common.h
CHANGED
File without changes
ggml/src/ggml-cann/ggml-cann.cpp
CHANGED
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
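Note: the supports_op hunk below reads logit_softcap out of op->op_params. In ggml, op_params is a small raw int32 scratch area on the node, and floats are stored and recovered bit-exact via memcpy. A self-contained illustration of that convention (values hypothetical):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[16] = {0};  // ggml nodes carry a fixed-size scratch like this
    const float scale = 0.088388f, max_bias = 0.0f, logit_softcap = 0.0f;
    std::memcpy((float*)op_params + 0, &scale,         sizeof(float));
    std::memcpy((float*)op_params + 1, &max_bias,      sizeof(float));
    std::memcpy((float*)op_params + 2, &logit_softcap, sizeof(float));

    float read_back = 0.0f;  // what supports_op and the kernel do in reverse
    std::memcpy(&read_back, (float*)op_params + 2, sizeof(float));
    std::printf("logit_softcap = %f\n", read_back);
    return 0;
}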
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
             return false;
     }