Commit c78b872 · Justine Tunney
Parent: b441739

ggml : rewrite silu and softmax for cpu (llama/7154)

This change upstreams llamafile's vectorized expf() functions. This lets
us compute softmax and silu more accurately than the short[65536] lookup
table that GGML previously used to make this operation go faster. We can
support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It
makes "make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf"
go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA, and 2.1x faster on AVX512.
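Reading the SIMD kernels in the diff below is easier with the scalar algorithm in mind. The following is a sketch of the same range reduction, exp(x) = 2^n * exp(b) with n = round(x*log2(e)) and b = x - n*ln(2), using the hunk's polynomial coefficients. It is an illustration, not the committed code: expf_sketch is a hypothetical helper, and the kernels' overflow/underflow paths are omitted, so it is only valid for roughly |x| < 87.

#include <math.h>
#include <stdint.h>
#include <string.h>

// Scalar sketch of the ggml_v_expf() range reduction (illustrative only).
static float expf_sketch(float x) {
    const float r = 0x1.8p23f;                    // 1.5*2^23: rounding bias
    const float z = fmaf(x, 0x1.715476p+0f, r);   // x*log2(e), rounded to an
    const float n = z - r;                        // integer via the bias trick
    float b = fmaf(n, -0x1.62e4p-1f, x);          // b = x - n*ln2_hi ...
    b = fmaf(n, -0x1.7f7d1cp-20f, b);             //       ... - n*ln2_lo
    uint32_t zbits;                               // the low mantissa bits of z
    memcpy(&zbits, &z, sizeof(zbits));            // hold n; shifting them into
    uint32_t kbits = (zbits << 23) + 0x3f800000u; // the exponent field of 1.0f
    float k;                                      // builds k = 2^n
    memcpy(&k, &kbits, sizeof(k));
    const float u = b * b;                        // minimax polynomial for
    const float j = fmaf(fmaf(fmaf(0x1.0e4020p-7f, b, 0x1.573e2ep-5f), u,
                              fmaf(0x1.555e66p-3f, b, 0x1.fffdb6p-2f)),
                         u, 0x1.ffffecp-1f * b);  // j ~ expm1(b)
    return fmaf(k, j, k);                         // 2^n * (1 + expm1(b))
}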
ggml.c
CHANGED
@@ -165,9 +165,6 @@ void ggml_print_backtrace(void) {
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
-#define GGML_SILU_FP16
-// #define GGML_CROSS_ENTROPY_EXP_FP16
-// #define GGML_FLASH_ATTN_EXP_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL 2

@@ -318,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
 // precomputed quick gelu table for f16 (128 KB)
 static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 
-// precomputed silu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_silu_f16[1 << 16];
-
-// precomputed exp table for f16 (128 KB)
-static ggml_fp16_t ggml_table_exp_f16[1 << 16];
-
 // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
 float ggml_table_f32_f16[1 << 16];
 
@@ -2085,52 +2076,291 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }
 
-//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-//    const uint16_t * i16 = (const uint16_t *) x;
-//    for (int i = 0; i < n; ++i) {
-//        y[i] = ggml_table_silu_f16[i16[i]];
-//    }
-//}
 
-#ifdef GGML_SILU_FP16
-inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
-    }
-}
-#else
-inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]);
-    }
-}
-#endif
+#if defined(__ARM_NEON)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t zero = vdupq_n_f32(0.0f);
+    const float32x4_t neg_x = vsubq_f32(zero, x);
+    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+    const __m512 r = _mm512_set1_ps(0x1.8p23f);
+    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+    const __m512 n = _mm512_sub_ps(z, r);
+    const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                                      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+    const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
+    const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
+    const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
+    const __m512 u = _mm512_mul_ps(b, b);
+    const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                                                     _mm512_set1_ps(0x1.573e2ep-5f)), u,
+                                                     _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                                                     _mm512_set1_ps(0x1.fffdb6p-2f))),
+                                     u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
+    if (_mm512_kortestz(c, c))
+        return _mm512_fmadd_ps(j, k, k);
+    const __m512i g = _mm512_and_si512(
+        _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
+        _mm512_set1_epi32(0x82000000u));
+    const __m512 s1 =
+        _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
+    const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+    const __mmask16 d =
+        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+    return _mm512_mask_blend_ps(
+        d, _mm512_mask_blend_ps(
+               c, _mm512_fmadd_ps(k, j, k),
+               _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
+        _mm512_mul_ps(s1, s1));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 zero = _mm512_setzero_ps();
+    const __m512 neg_x = _mm512_sub_ps(zero, x);
+    const __m512 exp_neg_x = ggml_v_expf(neg_x);
+    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+    const __m256 r = _mm256_set1_ps(0x1.8p23f);
+    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+    const __m256 n = _mm256_sub_ps(z, r);
+    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+    const __m256 k = _mm256_castsi256_ps(
+        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+    const __m256i c = _mm256_castps_si256(
+        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                      _mm256_set1_ps(126), _CMP_GT_OQ));
+    const __m256 u = _mm256_mul_ps(b, b);
+    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+        return _mm256_fmadd_ps(j, k, k);
+    const __m256i g = _mm256_and_si256(
+        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+        _mm256_set1_epi32(0x82000000u));
+    const __m256 s1 =
+        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+    const __m256i d = _mm256_castps_si256(
+        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                      _mm256_set1_ps(192), _CMP_GT_OQ));
+    return _mm256_or_ps(
+        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+        _mm256_andnot_ps(
+            _mm256_castsi256_ps(d),
+            _mm256_or_ps(
+                _mm256_and_ps(_mm256_castsi256_ps(c),
+                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 zero = _mm256_setzero_ps();
+    const __m256 neg_x = _mm256_sub_ps(zero, x);
+    const __m256 exp_neg_x = ggml_v_expf(neg_x);
+    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+    return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
+#else
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+    const __m128 r = _mm_set1_ps(0x1.8p23f);
+    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+    const __m128 n = _mm_sub_ps(z, r);
+    const __m128 b =
+        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    const __m128i c =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    const __m128 u = _mm_mul_ps(b, b);
+    const __m128 j =
+        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c))
+        return MADD128(j, k, k);
+    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+                                    _mm_set1_epi32(0x82000000u));
+    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    const __m128i d =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+    const __m128 one = _mm_set1_ps(1);
+    const __m128 zero = _mm_setzero_ps();
+    const __m128 neg_x = _mm_sub_ps(zero, x);
+    const __m128 exp_neg_x = ggml_v_expf(neg_x);
+    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+    return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
+static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+
+static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                               _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                               _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                            _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}
 
 inline static float ggml_silu_backward_f32(float x, float dy) {
     const float s = 1.0f/(1.0f + expf(-x));
     return dy*s*(1.0f + x*(1.0f - s));
 }
 
-#ifdef GGML_SILU_FP16
-inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
-    for (int i = 0; i < n; ++i) {
-        // we did not use x[i] to compute forward silu but its f16 equivalent
-        // take derivative at f16 of x[i]:
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        float usedx = GGML_FP16_TO_FP32(fp16);
-        dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
-    }
-}
-#else
 inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
     for (int i = 0; i < n; ++i) {
         dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
     }
 }
-#endif
 
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
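The backward helper retained by this hunk encodes the derivative of SiLU. For reference, with silu(x) = x * sigma(x) and sigma(x) = 1/(1 + e^(-x)), differentiating gives (a worked restatement, not part of the commit):

\frac{d}{dx}\,x\,\sigma(x) \;=\; \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) \;=\; \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr)

which, scaled by the incoming gradient dy, is exactly dy*s*(1.0f + x*(1.0f - s)) in ggml_silu_backward_f32.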
@@ -2922,8 +3152,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
                 ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                 ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
-                ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-                ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);

@@ -13600,22 +13828,7 @@ static void ggml_compute_forward_soft_max_f32(
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, wp);
 
-        ggml_float sum = 0.0;
-
-        uint16_t scvt;
-        for (int i = 0; i < nc; i++) {
-            if (wp[i] == -INFINITY) {
-                dp[i] = 0.0f;
-            } else {
-                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
-                memcpy(&scvt, &s, sizeof(scvt));
-                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-                sum += (ggml_float)val;
-                dp[i] = val;
-            }
-        }
-
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
         assert(sum > 0.0);
 
         sum = 1.0/sum;

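To see how the pieces fit together, here is a sketch of what the scalar soft max path reduces to after this hunk. softmax_row is a hypothetical wrapper for illustration only; ggml_vec_soft_max_f32 and ggml_float are the ones introduced above, and the plain loops stand in for ggml_vec_max_f32 / ggml_vec_scale_f32.

// Hypothetical wrapper showing the rewritten soft max data flow:
// row max -> fused exp + horizontal sum -> scale by 1/sum.
static void softmax_row(const int n, float * dst, const float * src) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {          // stand-in for ggml_vec_max_f32
        max = fmaxf(max, src[i]);
    }
    const ggml_float sum = ggml_vec_soft_max_f32(n, dst, src, max);
    const float scale = (float)(1.0/sum);
    for (int i = 0; i < n; i++) {          // stand-in for ggml_vec_scale_f32
        dst[i] *= scale;
    }
}

One design consequence worth noting: the old per-element -INFINITY branch disappears, because masked entries now simply come out of the exp kernel as 0.0f and contribute nothing to the sum.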
@@ -15374,37 +15587,7 @@ static void ggml_compute_forward_flash_attn_f32(
                 vvexpf(S, S, &Mup);
                 ggml_vec_sum_f32(Mup, &sum, S);
 #else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
-                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                    if (i >= masked_begin) {
-                        break;
-                    }
-                    float * SS = S + i;
-
-                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (i + j >= masked_begin) {
-                            break;
-                        } else if (SS[j] == -INFINITY) {
-                            SS[j] = 0.0f;
-                        } else {
-#ifndef GGML_FLASH_ATTN_EXP_FP16
-                            const float val = expf(SS[j] - max);
-#else
-                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                            memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-#endif
-                            sump[j] += (ggml_float)val;
-                            SS[j] = val;
-                        }
-                    }
-                }
-
-                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                    sum += sump[i];
-                }
+                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 #endif
             }
 
@@ -15586,28 +15769,7 @@ static void ggml_compute_forward_flash_attn_f16(
                 vvexpf(S, S, &Mup);
                 ggml_vec_sum_f32(Mup, &sum, S);
 #else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
-                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                    float * SS = S + i;
-
-                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (SS[j] == -INFINITY) {
-                            SS[j] = 0.0f;
-                        } else {
-                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                            memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-                            sump[j] += (ggml_float)val;
-                            SS[j] = val;
-                        }
-                    }
-                }
-
-                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                    sum += sump[i];
-                }
+                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 #endif
             }
 
@@ -16234,38 +16396,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
                 vvexpf(SM, SM, &Mup);
                 ggml_vec_sum_f32(Mup, &sum, SM);
 #else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
-                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                    if (i >= masked_begin) {
-                        break;
-                    }
-                    float * SR = S + i;
-                    float * SW = SM + i;
-
-                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (i + j >= masked_begin) {
-                            break;
-                        } else if (SR[j] == -INFINITY) {
-                            SW[j] = 0.0f;
-                        } else {
-#ifndef GGML_FLASH_ATTN_EXP_FP16
-                            const float val = expf(SR[j] - max);
-#else
-                            ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
-                            memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-#endif
-                            sump[j] += (ggml_float)val;
-                            SW[j] = val;
-                        }
-                    }
-                }
-
-                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                    sum += sump[i];
-                }
+                sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
 #endif
             }
 
@@ -17291,35 +17422,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             assert(!isnan(s1[i]));
         }
 #endif
-        // soft_max
-        ggml_float sum = 0.0;
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(nc, &max, s0);
 
-            uint16_t scvt; UNUSED(scvt);
-            for (int i = 0; i < nc; i++) {
-                if (s0[i] == -INFINITY) {
-                    st[i] = 0.0f;
-                } else {
-#ifndef GGML_CROSS_ENTROPY_EXP_FP16
-                    const float s = s0[i] - max;
-                    const float val = expf(s);
-#else
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
-                    memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-#endif
-                    sum += (ggml_float)val;
-                    st[i] = val;
-                }
-            }
+        // soft_max
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
+        assert(sum > 0.0);
+        sum = (1.0 - eps) / sum;
 
-            assert(sum > 0.0);
-            // sum = 1.0/sum;
-        }
         // avoid log(0) by rescaling from [0..1] to [eps..1]
-        sum = (1.0 - eps) / sum;
         ggml_vec_scale_f32(nc, st, sum);
         ggml_vec_add1_f32(nc, st, st, eps);
         ggml_vec_log_f32(nc, st, st);

@@ -17409,32 +17520,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 #endif
 
         // soft_max
-        ggml_float sum = 0.0;
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(nc, &max, s0);
-
-            uint16_t scvt; UNUSED(scvt);
-            for (int i = 0; i < nc; i++) {
-                if (s0[i] == -INFINITY) {
-                    ds0[i] = 0.0f;
-                } else {
-#ifndef GGML_CROSS_ENTROPY_EXP_FP16
-                    const float s = s0[i] - max;
-                    const float val = expf(s);
-#else
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
-                    memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-#endif
-                    sum += (ggml_float)val;
-                    ds0[i] = val;
-                }
-            }
-
-            assert(sum > 0.0);
-            sum = (1.0 - eps)/sum;
-        }
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        assert(sum > 0.0);
+        sum = (1.0 - eps)/sum;
 
         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
         ggml_vec_scale_f32(nc, ds0, sum);