Commit a027c1d
David Huang committed
Parent(s): dbc0180
HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (llama/12032)
Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16
- ggml/CMakeLists.txt +1 -0
- ggml/src/ggml-cuda/common.cuh +13 -2
- ggml/src/ggml-cuda/fattn-common.cuh +38 -28
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +73 -63
- ggml/src/ggml-cuda/fattn.cu +10 -2
- ggml/src/ggml-hip/CMakeLists.txt +10 -0
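The heart of the port is a compile-time switch in fattn-wmma-f16.cu that aliases either nvcuda::wmma or rocwmma to a common wmma namespace, so the existing fragment-based kernel code compiles unchanged on both backends. A condensed sketch of that pattern, lifted from the hunk further down (the standalone frag_a typedef at the end is illustrative only and not part of the commit):

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
namespace wmma = nvcuda::wmma;       // NVIDIA: CUDA WMMA intrinsics
#elif defined(GGML_HIP_ROCWMMA_FATTN)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;            // AMD: rocWMMA on CDNA and RDNA3+
#endif

// With the alias in place, the same fragment declarations work on either backend.
// Illustrative 16x16x16 FP16 tile only; the real kernel derives its tile shape from template parameters.
typedef wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a;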
ggml/CMakeLists.txt
CHANGED
@@ -162,6 +162,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
 option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
+option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
ggml/src/ggml-cuda/common.cuh
CHANGED
@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA

+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {

 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
 }

 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
@@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

@@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
@@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

+            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
@@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

+            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }

+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }

     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     half2 sum2 = make_half2(0.0f, 0.0f);

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const half2 K_ik = K_h2[k_KQ];
+        sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
     }

     return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;

 #pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const half2 K_ik = K_h2[k_KQ];
+        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }

     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED
@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"

 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16(

     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

@@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
@@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
@@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
+            wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }

@@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16(
         const int j = j0 + threadIdx.y;

         if (std::is_same<KQ_acc_t, float>::value) {
+            float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];

                 if (use_logit_softcap) {
+                    KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                 }
             }

             float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
             }
+            KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);

             const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
             KQ_max_scale_f[j0/nwarps] = expf(diff);
@@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16(

             float KQ_rowsum_add = 0.0f;
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                KQ_f_tmp[k0/warp_size] = expf(diff);
                 if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_f_tmp[k0/warp_size] = 0.0f;
                 }
+                KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
             }
+            KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
         } else {
+            half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];

                 if (use_logit_softcap) {
                     // There is no dedicated tangens hyperbolicus function for half2.
+                    KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                    KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+                                           /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));

+                    KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                 }
             }

             half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
             }
+            KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
             const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
             KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
             const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
@@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16(

             half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

+                const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                KQ2_tmp[k0/warp_size] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
             }
+            KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);

             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
@@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);
@@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
         }

 #pragma unroll
@@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

             frag_a_V v_a;
+            wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
+                wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
             }
         }
     }
@@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                D_padded, wmma::mem_col_major);
         }
     }

@@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16(
     }

 #pragma unroll
+    for (int i0 = 0; i0 < D/2; i0 += warp_size) {
         const int i = i0 + threadIdx.x;
+        if (i0 + warp_size > D/2 && i >= D/2) {
             break;
         }

@@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16(
     }

 #pragma unroll
+    for (int i0 = 0; i0 < D; i0 += warp_size) {
         const int i = i0 + threadIdx.x;
+        if (i0 + warp_size > D && i >= D) {
             break;
         }
         float dst_val = VKQ[j_VKQ*D_padded + i];
@@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }

 constexpr int get_max_power_of_2(int x) {
@@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q = dst->src[0];

     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;

     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
@@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
ggml/src/ggml-cuda/fattn.cu
CHANGED
@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

     if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+        if (fp16_mma_available(cc)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            return;
+        }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+        // On AMD the tile kernels perform poorly, use the vec kernel instead:
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
@@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int gqa_ratio = Q->ne[2] / K->ne[2];
     const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
         K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
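Net effect of the fattn.cu change above: on AMD devices the rocWMMA FlashAttention kernel is now tried first whenever it was compiled in and the device reports FP16 MMA support, with the vec kernels kept as the fallback. A condensed sketch of the resulting control flow, using only identifiers that appear in the hunk above (surrounding declarations and the unchanged else-branch are elided):

if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
    if (fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); // rocWMMA FlashAttention path
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
    } else {
        // F32 vec fallback (unchanged by this commit, not shown in the hunk above)
    }
}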
ggml/src/ggml-hip/CMakeLists.txt
CHANGED
@@ -39,6 +39,12 @@ endif()
 find_package(hip REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
+if (GGML_HIP_ROCWMMA_FATTN)
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
+    if (NOT ${FOUND_ROCWMMA})
+        message(FATAL_ERROR "rocwmma has not been found")
+    endif()
+endif()

 if (${hip_VERSION} VERSION_LESS 5.5)
     message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()

+if (GGML_HIP_ROCWMMA_FATTN)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()