JohannesGaessler commited on
Commit
c315fbf
·
unverified ·
1 Parent(s): d6abb6a

ggml: cache sin/cos for RoPE (llama/4908)

Browse files
Files changed (1) hide show
  1. ggml.c +32 -14
ggml.c CHANGED
@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
11638
  return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
11639
  }
11640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11641
  void ggml_rope_yarn_corr_dims(
11642
  int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
11643
  ) {
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
11720
  for (int64_t i3 = 0; i3 < ne3; i3++) {
11721
  for (int64_t i2 = 0; i2 < ne2; i2++) {
11722
  const int64_t p = pos[i2];
 
 
 
 
 
 
11723
  for (int64_t i1 = 0; i1 < ne1; i1++) {
11724
  if (ir++ < ir0) continue;
11725
  if (ir > ir1) break;
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
11753
  }
11754
  } else if (!is_neox) {
11755
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11756
- float cos_theta, sin_theta;
11757
- rope_yarn(
11758
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11759
- );
11760
- sin_theta *= sin_sign;
11761
 
11762
  // zeta scaling for xPos only:
11763
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
11764
  if (xpos_down) zeta = 1.0f / zeta;
11765
 
11766
- theta_base *= theta_scale;
11767
-
11768
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11769
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11770
 
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
11888
  for (int64_t i3 = 0; i3 < ne3; i3++) {
11889
  for (int64_t i2 = 0; i2 < ne2; i2++) {
11890
  const int64_t p = pos[i2];
 
 
 
 
 
 
11891
  for (int64_t i1 = 0; i1 < ne1; i1++) {
11892
  if (ir++ < ir0) continue;
11893
  if (ir > ir1) break;
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
11921
  }
11922
  } else if (!is_neox) {
11923
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11924
- float cos_theta, sin_theta;
11925
- rope_yarn(
11926
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11927
- );
11928
- sin_theta *= sin_sign;
11929
-
11930
- theta_base *= theta_scale;
11931
 
11932
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11933
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
16722
  }
16723
  } break;
16724
  case GGML_OP_SOFT_MAX:
 
16725
  {
16726
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
16727
  } break;
 
11638
  return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
11639
  }
11640
 
11641
+ static void ggml_rope_cache_init(
11642
+ float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
11643
+ float * cache, float sin_sign, float theta_scale
11644
+ ) {
11645
+ float theta = theta_base;
11646
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11647
+ rope_yarn(
11648
+ theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
11649
+ );
11650
+ cache[i0 + 1] *= sin_sign;
11651
+
11652
+ theta *= theta_scale;
11653
+ }
11654
+ }
11655
+
11656
  void ggml_rope_yarn_corr_dims(
11657
  int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
11658
  ) {
 
11735
  for (int64_t i3 = 0; i3 < ne3; i3++) {
11736
  for (int64_t i2 = 0; i2 < ne2; i2++) {
11737
  const int64_t p = pos[i2];
11738
+
11739
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11740
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11741
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11742
+ }
11743
+
11744
  for (int64_t i1 = 0; i1 < ne1; i1++) {
11745
  if (ir++ < ir0) continue;
11746
  if (ir > ir1) break;
 
11774
  }
11775
  } else if (!is_neox) {
11776
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11777
+ const float cos_theta = cache[i0 + 0];
11778
+ const float sin_theta = cache[i0 + 1];
 
 
 
11779
 
11780
  // zeta scaling for xPos only:
11781
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
11782
  if (xpos_down) zeta = 1.0f / zeta;
11783
 
 
 
11784
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11785
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11786
 
 
11904
  for (int64_t i3 = 0; i3 < ne3; i3++) {
11905
  for (int64_t i2 = 0; i2 < ne2; i2++) {
11906
  const int64_t p = pos[i2];
11907
+
11908
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11909
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11910
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11911
+ }
11912
+
11913
  for (int64_t i1 = 0; i1 < ne1; i1++) {
11914
  if (ir++ < ir0) continue;
11915
  if (ir > ir1) break;
 
11943
  }
11944
  } else if (!is_neox) {
11945
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11946
+ const float cos_theta = cache[i0 + 0];
11947
+ const float sin_theta = cache[i0 + 1];
 
 
 
 
 
11948
 
11949
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11950
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
16739
  }
16740
  } break;
16741
  case GGML_OP_SOFT_MAX:
16742
+ case GGML_OP_ROPE:
16743
  {
16744
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
16745
  } break;