ggerganov committed on
Commit
a2015c0
·
1 Parent(s): 06b8aa3

stream : partial encoder experiments

Browse files
Files changed (3) hide show
  1. examples/stream/stream.cpp +4 -2
  2. whisper.cpp +47 -20
  3. whisper.h +3 -0
examples/stream/stream.cpp CHANGED
@@ -221,6 +221,7 @@ int main(int argc, char ** argv) {
221
  const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
222
  const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
223
  const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
 
224
 
225
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
226
  std::vector<float> pcmf32_old;
@@ -303,7 +304,7 @@ int main(int argc, char ** argv) {
303
  //const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
304
 
305
  // take up to params.length_ms audio from previous iteration
306
- const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new));
307
 
308
  //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
309
 
@@ -379,7 +380,8 @@ int main(int argc, char ** argv) {
379
  if ((n_iter % n_new_line) == 0) {
380
  printf("\n");
381
 
382
- pcmf32_old.clear();
 
383
  }
384
  }
385
  }
 
221
  const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
222
  const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
223
  const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
224
+ const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE;
225
 
226
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
227
  std::vector<float> pcmf32_old;
 
304
  //const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
305
 
306
  // take up to params.length_ms audio from previous iteration
307
+ const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
308
 
309
  //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
310
 
 
380
  if ((n_iter % n_new_line) == 0) {
381
  printf("\n");
382
 
383
+ // keep part of the audio for next iteration to try to mitigate word boundary issues
384
+ pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
385
  }
386
  }
387
  }
whisper.cpp CHANGED
@@ -613,7 +613,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
613
  const int n_audio_state = hparams.n_audio_state;
614
  const int n_audio_layer = hparams.n_audio_layer;
615
 
616
- const int n_text_ctx = hparams.n_text_ctx;
617
  const int n_text_state = hparams.n_text_state;
618
  const int n_text_layer = hparams.n_text_layer;
619
 
@@ -748,7 +748,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
748
  const int n_audio_state = hparams.n_audio_state;
749
  const int n_audio_layer = hparams.n_audio_layer;
750
 
751
- const int n_text_ctx = hparams.n_text_ctx;
752
  const int n_text_state = hparams.n_text_state;
753
  const int n_text_layer = hparams.n_text_layer;
754
 
@@ -967,13 +967,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
967
 
968
  // key/value memory for the cross-attention layer
969
  {
970
- const int n_audio_ctx = hparams.n_audio_ctx;
971
 
972
  const int n_mem = n_text_layer*n_audio_ctx;
973
  const int n_elements = n_text_state*n_mem;
974
 
975
  model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
976
  model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
 
 
 
977
  }
978
 
979
  const size_t memory_size =
@@ -1076,13 +1079,11 @@ static bool whisper_encode(
1076
  const auto & mel_inp = wctx.mel;
1077
  const auto & hparams = model.hparams;
1078
 
1079
- const int n_ctx = hparams.n_audio_ctx;
1080
  const int n_state = hparams.n_audio_state;
1081
  const int n_head = hparams.n_audio_head;
1082
  const int n_layer = hparams.n_audio_layer;
1083
 
1084
- const int N = n_ctx;
1085
-
1086
  const int n_mels = hparams.n_mels;
1087
  assert(mel_inp.n_mel == n_mels);
1088
 
@@ -1132,7 +1133,24 @@ static bool whisper_encode(
1132
  cur = ggml_gelu(ctx0, cur);
1133
  }
1134
 
1135
- cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1136
 
1137
  struct ggml_tensor * inpL = cur;
1138
 
@@ -1198,14 +1216,14 @@ static bool whisper_encode(
1198
  ggml_permute(ctxL,
1199
  ggml_cpy(ctxL,
1200
  Qcur,
1201
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1202
  0, 2, 1, 3);
1203
 
1204
  struct ggml_tensor * K =
1205
  ggml_permute(ctxL,
1206
  ggml_cpy(ctxL,
1207
  Kcur,
1208
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1209
  0, 2, 1, 3);
1210
 
1211
  struct ggml_tensor * V =
@@ -1213,9 +1231,9 @@ static bool whisper_encode(
1213
  ggml_permute(ctxL,
1214
  ggml_reshape_3d(ctxL,
1215
  Vcur,
1216
- n_state/n_head, n_head, N),
1217
  1, 2, 0, 3),
1218
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1219
  );
1220
 
1221
  struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1224,14 +1242,14 @@ static bool whisper_encode(
1224
  ggml_permute(ctxL,
1225
  ggml_cpy(ctxL,
1226
  Qcur,
1227
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1228
  0, 2, 1, 3);
1229
 
1230
  struct ggml_tensor * K =
1231
  ggml_permute(ctxL,
1232
  ggml_cpy(ctxL,
1233
  Kcur,
1234
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1235
  0, 2, 1, 3);
1236
 
1237
  // K * Q
@@ -1249,7 +1267,7 @@ static bool whisper_encode(
1249
  // ggml_permute(ctxL,
1250
  // ggml_cpy(ctxL,
1251
  // Vcur,
1252
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1253
  // 1, 2, 0, 3);
1254
 
1255
  //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1277,9 @@ static bool whisper_encode(
1259
  ggml_permute(ctxL,
1260
  ggml_reshape_3d(ctxL,
1261
  Vcur,
1262
- n_state/n_head, n_head, N),
1263
  0, 2, 1, 3),
1264
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
1265
  );
1266
 
1267
  struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1271,7 +1289,7 @@ static bool whisper_encode(
1271
 
1272
  cur = ggml_cpy(ctxL,
1273
  KQV_merged,
1274
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1275
  }
1276
 
1277
  // projection
@@ -1425,6 +1443,8 @@ static bool whisper_encode(
1425
  Vcross),
1426
  Vcross);
1427
 
 
 
1428
  struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
1429
  struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
1430
 
@@ -1474,7 +1494,8 @@ static bool whisper_decode(
1474
  const int n_layer = hparams.n_text_layer;
1475
 
1476
  const int N = n_tokens;
1477
- const int M = hparams.n_audio_ctx;
 
1478
 
1479
  struct ggml_init_params params = {
1480
  .mem_size = wctx.buf_compute.size(),
@@ -2662,7 +2683,7 @@ int whisper_full(
2662
  //}
2663
 
2664
  // end of text token
2665
- if (token.id == whisper_token_eot(ctx)) {
2666
  if (result_len == 0) {
2667
  if (seek + seek_delta + 100 >= seek_end) {
2668
  result_len = i + 1;
@@ -2671,6 +2692,12 @@ int whisper_full(
2671
  fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
2672
  }
2673
  }
 
 
 
 
 
 
2674
  break;
2675
  }
2676
 
@@ -2850,7 +2877,7 @@ int whisper_full_parallel(
2850
 
2851
  // key/value memory for the cross-attention layer
2852
  {
2853
- const int n_audio_ctx = hparams.n_audio_ctx;
2854
 
2855
  const int n_mem = n_text_layer*n_audio_ctx;
2856
  const int n_elements = n_text_state*n_mem;
 
613
  const int n_audio_state = hparams.n_audio_state;
614
  const int n_audio_layer = hparams.n_audio_layer;
615
 
616
+ const int n_text_ctx = hparams.n_text_ctx;
617
  const int n_text_state = hparams.n_text_state;
618
  const int n_text_layer = hparams.n_text_layer;
619
 
 
748
  const int n_audio_state = hparams.n_audio_state;
749
  const int n_audio_layer = hparams.n_audio_layer;
750
 
751
+ const int n_text_ctx = hparams.n_text_ctx;
752
  const int n_text_state = hparams.n_text_state;
753
  const int n_text_layer = hparams.n_text_layer;
754
 
 
967
 
968
  // key/value memory for the cross-attention layer
969
  {
970
+ const int n_audio_ctx = hparams.n_audio_ctx;
971
 
972
  const int n_mem = n_text_layer*n_audio_ctx;
973
  const int n_elements = n_text_state*n_mem;
974
 
975
  model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
976
  model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
977
+
978
+ //memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
979
+ //memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
980
  }
981
 
982
  const size_t memory_size =
 
1079
  const auto & mel_inp = wctx.mel;
1080
  const auto & hparams = model.hparams;
1081
 
1082
+ const int n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX;
1083
  const int n_state = hparams.n_audio_state;
1084
  const int n_head = hparams.n_audio_head;
1085
  const int n_layer = hparams.n_audio_layer;
1086
 
 
 
1087
  const int n_mels = hparams.n_mels;
1088
  assert(mel_inp.n_mel == n_mels);
1089
 
 
1133
  cur = ggml_gelu(ctx0, cur);
1134
  }
1135
 
1136
+ //static int iter = -1;
1137
+ //const int n_iter = 1500/n_ctx;
1138
+
1139
+ //iter = (iter + 1) % n_iter;
1140
+
1141
+ //if (iter == 0) {
1142
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1143
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1144
+ //}
1145
+
1146
+ static int iter = 0;
1147
+
1148
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1149
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1150
+
1151
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1152
+
1153
+ cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
1154
 
1155
  struct ggml_tensor * inpL = cur;
1156
 
 
1216
  ggml_permute(ctxL,
1217
  ggml_cpy(ctxL,
1218
  Qcur,
1219
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1220
  0, 2, 1, 3);
1221
 
1222
  struct ggml_tensor * K =
1223
  ggml_permute(ctxL,
1224
  ggml_cpy(ctxL,
1225
  Kcur,
1226
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1227
  0, 2, 1, 3);
1228
 
1229
  struct ggml_tensor * V =
 
1231
  ggml_permute(ctxL,
1232
  ggml_reshape_3d(ctxL,
1233
  Vcur,
1234
+ n_state/n_head, n_head, n_ctx),
1235
  1, 2, 0, 3),
1236
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
1237
  );
1238
 
1239
  struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
 
1242
  ggml_permute(ctxL,
1243
  ggml_cpy(ctxL,
1244
  Qcur,
1245
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1246
  0, 2, 1, 3);
1247
 
1248
  struct ggml_tensor * K =
1249
  ggml_permute(ctxL,
1250
  ggml_cpy(ctxL,
1251
  Kcur,
1252
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1253
  0, 2, 1, 3);
1254
 
1255
  // K * Q
 
1267
  // ggml_permute(ctxL,
1268
  // ggml_cpy(ctxL,
1269
  // Vcur,
1270
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1271
  // 1, 2, 0, 3);
1272
 
1273
  //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
 
1277
  ggml_permute(ctxL,
1278
  ggml_reshape_3d(ctxL,
1279
  Vcur,
1280
+ n_state/n_head, n_head, n_ctx),
1281
  0, 2, 1, 3),
1282
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
1283
  );
1284
 
1285
  struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
 
1289
 
1290
  cur = ggml_cpy(ctxL,
1291
  KQV_merged,
1292
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
1293
  }
1294
 
1295
  // projection
 
1443
  Vcross),
1444
  Vcross);
1445
 
1446
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1447
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1448
  struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
1449
  struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
1450
 
 
1494
  const int n_layer = hparams.n_text_layer;
1495
 
1496
  const int N = n_tokens;
1497
+ //const int M = hparams.n_audio_ctx;
1498
+ const int M = WHISPER_EXPERIMENT_AUDIO_CTX;
1499
 
1500
  struct ggml_init_params params = {
1501
  .mem_size = wctx.buf_compute.size(),
 
2683
  //}
2684
 
2685
  // end of text token
2686
+ if (token.id == whisper_token_eot(ctx) || (i > WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT)) {
2687
  if (result_len == 0) {
2688
  if (seek + seek_delta + 100 >= seek_end) {
2689
  result_len = i + 1;
 
2692
  fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
2693
  }
2694
  }
2695
+
2696
+ // TODO: TMP TO MAKE STREAM WORK ON RPI4 ===
2697
+ result_len = i + 1;
2698
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
2699
+ // =========================================
2700
+
2701
  break;
2702
  }
2703
 
 
2877
 
2878
  // key/value memory for the cross-attention layer
2879
  {
2880
+ const int n_audio_ctx = hparams.n_audio_ctx;
2881
 
2882
  const int n_mem = n_text_layer*n_audio_ctx;
2883
  const int n_elements = n_text_state*n_mem;
whisper.h CHANGED
@@ -24,6 +24,9 @@
24
  #define WHISPER_HOP_LENGTH 160
25
  #define WHISPER_CHUNK_SIZE 30
26
 
 
 
 
27
  #ifdef __cplusplus
28
  extern "C" {
29
  #endif
 
24
  #define WHISPER_HOP_LENGTH 160
25
  #define WHISPER_CHUNK_SIZE 30
26
 
27
+ #define WHISPER_EXPERIMENT_AUDIO_CTX 512
28
+ #define WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT 32
29
+
30
  #ifdef __cplusplus
31
  extern "C" {
32
  #endif