ggerganov committed
Commit 8311a60 (unverified)
Parent: f527611

Try to improve the token sampling strategy (#193)

* whisper : try to improve the token sampling strategy

- Add the "max_initial_timestamp" token logic from OpenAI
- Disallow sampling timestamps that are in the past

* whisper : fix the max initial timestamp logic + fallback decoding

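The core of the new masking, as a standalone sketch (illustrative only: it operates on a plain probability array probs of size n_logits, while whisper.cpp applies the same idea to its (prob, id) pairs; token_beg is the id of the first timestamp token, as in the diff below):

#include <cmath> // INFINITY

// "max_initial_timestamp" masking: the first sampled token of a window may not
// be a timestamp more than 100 increments past token_beg
static void mask_initial_timestamps(float * probs, int n_logits, int token_beg) {
    for (int i = token_beg + 101; i < n_logits; ++i) {
        probs[i] = -INFINITY; // can never be sampled
    }
}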
Files changed (2)
  1. whisper.cpp +45 -52
  2. whisper.h +1 -1
whisper.cpp CHANGED
@@ -1846,7 +1846,9 @@ static bool whisper_decode(
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs,
+        bool force_timestamp,
+        bool is_initial) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
@@ -1869,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
         max_tx = std::max(max_tx, probs_id[i].first);
     }
 
-    for (int i = vocab.token_beg; i < n_logits; i++) {
+    const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+    const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
+
+    // the initial timestamp cannot be larger than 100
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+    if (is_initial) {
+        for (int i = i0; i < n_logits; ++ i) {
+            probs_id[i].first = -INFINITY;
+        }
+    }
+
+    for (int i = vocab.token_beg; i < i1; i++) {
         sum_ts += probs_id[i].first;
         if (probs_id[i].first > max_ts) {
             max_ts = probs_id[i].first;
@@ -1879,7 +1892,7 @@ static whisper_token_data whisper_sample_best(
 
     // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
     // timestamp token
-    if (sum_ts > max_tx) {
+    if (sum_ts > max_tx || force_timestamp) {
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
         for (int i = 0; i < vocab.token_beg; i++) {
             probs_id[i].first = -INFINITY;
@@ -1921,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
     return result;
 }
 
-// samples only from the timestamps tokens
-static whisper_vocab::id whisper_sample_timestamp(
-        const whisper_vocab & vocab,
-        const float * probs) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
-
-    for (int i = vocab.token_beg + 1; i < n_logits; i++) {
-        probs_id.push_back(std::make_pair(probs[i], i));
-    }
-
-    const int top_k = 10;
-
-    // find the top K tokens
-    std::partial_sort(
-            probs_id.begin(),
-            probs_id.begin() + top_k, probs_id.end(),
-            [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    probs_id.resize(top_k);
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs_id.size(); i++) {
-    //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
-    //}
-
-    return probs_id[0].second;
-}
-
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2284,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
 struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // TODO: simplify
-    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
     return res;
 }
 
-whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
+struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // TODO: simplify
-    auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
@@ -2694,7 +2672,6 @@ int whisper_full(
 
     prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-    bool done = false;
     int seek_delta = 100*WHISPER_CHUNK_SIZE;
 
     // print the prompt
@@ -2708,7 +2685,9 @@ int whisper_full(
         int result_len = 0;
         tokens_cur.clear();
 
-        for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
+        bool failed = false;
+
+        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
                 return 8;
@@ -2725,15 +2704,19 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                auto token = whisper_sample_best(ctx);
-
-                if (i == 0) {
-                    token.tid = whisper_token_beg(ctx);
-                }
+                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
 
                 // timestamp token - update sliding window
                 if (token.id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
+                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+                    // do not allow to go back in time
+                    if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
+                        seek_delta > seek_delta_new && result_len < i) {
+                        break;
+                    }
+
+                    seek_delta = seek_delta_new;
                     result_len = i + 1;
                 }
 
@@ -2752,8 +2735,8 @@ int whisper_full(
                 if (seek + seek_delta + 100 >= seek_end) {
                     result_len = i + 1;
                 } else {
-                    // TODO: figure out how to resolve this
-                    fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
+                    failed = true;
+                    break;
                 }
             }
 
@@ -2772,11 +2755,21 @@ int whisper_full(
                 }
             }
 
-            if (done) {
+            // sometimes, the decoding can get stuck in a repetition loop
+            // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+            // the sliding window by 1 second
+            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+                failed = true;
                 break;
            }
        }
 
+        if (failed) {
+            fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
+            seek += 100;
+            continue;
+        }
+
         // shrink down to result_len
         tokens_cur.resize(result_len);
 
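Two rules in the whisper_full() changes above are easy to miss in diff form: a sampled timestamp that would move the sliding window backwards aborts the current decode, and a window that never yields a usable timestamp (e.g. a repetition loop) is flagged as failed and advanced by one second. A condensed sketch with illustrative names (the diff itself uses the literal 100*WHISPER_CHUNK_SIZE and a local failed flag):

// rule 1: do not allow to go back in time
// seek_delta_init stands in for 100*WHISPER_CHUNK_SIZE, the "no timestamp yet" value
static bool accept_timestamp(int & seek_delta, int seek_delta_new,
                             int result_len, int i, int seek_delta_init) {
    if (seek_delta != seek_delta_init &&  // a timestamp was already accepted
        seek_delta > seek_delta_new &&    // the new one points earlier in time
        result_len < i) {                 // and tokens were produced past it
        return false;                     // caller stops decoding this window
    }
    seek_delta = seek_delta_new;
    return true;
}

// rule 2: on failure, advance the sliding window by 1 second and retry
// (seek is in units of 10 ms, so +100 == +1 s, matching the diff)
static void advance_on_failure(int & seek) {
    seek += 100;
}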
whisper.h CHANGED
@@ -137,7 +137,7 @@ extern "C" {
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
     WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
-    WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
+    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
 
     // Return the id of the specified language, returns -1 if not found
     WHISPER_API int whisper_lang_id(const char * lang);
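For callers, the visible change is the whisper_sample_timestamp() signature: it now returns a full whisper_token_data instead of a bare whisper_token and takes an is_initial flag. A minimal caller sketch (hypothetical helper, assuming a context holding freshly decoded logits), mirroring the line added in whisper_full():

#include "whisper.h"

// force a timestamp for the first token of a window, sample normally afterwards
static whisper_token_data sample_next(struct whisper_context * ctx, int i) {
    return (i == 0) ? whisper_sample_timestamp(ctx, /*is_initial=*/true)
                    : whisper_sample_best(ctx);
}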