Spaces:
Running
Running
ggml: cache sin/cos for RoPE (llama/4908)
Browse files
ggml.c
CHANGED
|
@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
|
|
| 11638 |
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
| 11639 |
}
|
| 11640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11641 |
void ggml_rope_yarn_corr_dims(
|
| 11642 |
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
| 11643 |
) {
|
|
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
|
|
| 11720 |
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
| 11721 |
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
| 11722 |
const int64_t p = pos[i2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11723 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 11724 |
if (ir++ < ir0) continue;
|
| 11725 |
if (ir > ir1) break;
|
|
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
|
|
| 11753 |
}
|
| 11754 |
} else if (!is_neox) {
|
| 11755 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 11756 |
-
float cos_theta, sin_theta;
|
| 11757 |
-
rope_yarn(
|
| 11758 |
-
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
| 11759 |
-
);
|
| 11760 |
-
sin_theta *= sin_sign;
|
| 11761 |
|
| 11762 |
// zeta scaling for xPos only:
|
| 11763 |
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
| 11764 |
if (xpos_down) zeta = 1.0f / zeta;
|
| 11765 |
|
| 11766 |
-
theta_base *= theta_scale;
|
| 11767 |
-
|
| 11768 |
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 11769 |
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 11770 |
|
|
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
|
|
| 11888 |
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
| 11889 |
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
| 11890 |
const int64_t p = pos[i2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11891 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 11892 |
if (ir++ < ir0) continue;
|
| 11893 |
if (ir > ir1) break;
|
|
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
|
|
| 11921 |
}
|
| 11922 |
} else if (!is_neox) {
|
| 11923 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 11924 |
-
float cos_theta, sin_theta;
|
| 11925 |
-
rope_yarn(
|
| 11926 |
-
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
| 11927 |
-
);
|
| 11928 |
-
sin_theta *= sin_sign;
|
| 11929 |
-
|
| 11930 |
-
theta_base *= theta_scale;
|
| 11931 |
|
| 11932 |
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 11933 |
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|
| 16722 |
}
|
| 16723 |
} break;
|
| 16724 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 16725 |
{
|
| 16726 |
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
| 16727 |
} break;
|
|
|
|
| 11638 |
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
| 11639 |
}
|
| 11640 |
|
| 11641 |
+
static void ggml_rope_cache_init(
|
| 11642 |
+
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
| 11643 |
+
float * cache, float sin_sign, float theta_scale
|
| 11644 |
+
) {
|
| 11645 |
+
float theta = theta_base;
|
| 11646 |
+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 11647 |
+
rope_yarn(
|
| 11648 |
+
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
| 11649 |
+
);
|
| 11650 |
+
cache[i0 + 1] *= sin_sign;
|
| 11651 |
+
|
| 11652 |
+
theta *= theta_scale;
|
| 11653 |
+
}
|
| 11654 |
+
}
|
| 11655 |
+
|
| 11656 |
void ggml_rope_yarn_corr_dims(
|
| 11657 |
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
| 11658 |
) {
|
|
|
|
| 11735 |
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
| 11736 |
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
| 11737 |
const int64_t p = pos[i2];
|
| 11738 |
+
|
| 11739 |
+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 11740 |
+
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
| 11741 |
+
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
| 11742 |
+
}
|
| 11743 |
+
|
| 11744 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 11745 |
if (ir++ < ir0) continue;
|
| 11746 |
if (ir > ir1) break;
|
|
|
|
| 11774 |
}
|
| 11775 |
} else if (!is_neox) {
|
| 11776 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 11777 |
+
const float cos_theta = cache[i0 + 0];
|
| 11778 |
+
const float sin_theta = cache[i0 + 1];
|
|
|
|
|
|
|
|
|
|
| 11779 |
|
| 11780 |
// zeta scaling for xPos only:
|
| 11781 |
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
| 11782 |
if (xpos_down) zeta = 1.0f / zeta;
|
| 11783 |
|
|
|
|
|
|
|
| 11784 |
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 11785 |
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 11786 |
|
|
|
|
| 11904 |
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
| 11905 |
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
| 11906 |
const int64_t p = pos[i2];
|
| 11907 |
+
|
| 11908 |
+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 11909 |
+
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
| 11910 |
+
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
| 11911 |
+
}
|
| 11912 |
+
|
| 11913 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 11914 |
if (ir++ < ir0) continue;
|
| 11915 |
if (ir > ir1) break;
|
|
|
|
| 11943 |
}
|
| 11944 |
} else if (!is_neox) {
|
| 11945 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 11946 |
+
const float cos_theta = cache[i0 + 0];
|
| 11947 |
+
const float sin_theta = cache[i0 + 1];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11948 |
|
| 11949 |
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 11950 |
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
| 16739 |
}
|
| 16740 |
} break;
|
| 16741 |
case GGML_OP_SOFT_MAX:
|
| 16742 |
+
case GGML_OP_ROPE:
|
| 16743 |
{
|
| 16744 |
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
| 16745 |
} break;
|