ggerganov commited on
Commit
c46c0dc
·
unverified ·
1 Parent(s): ca49ab0

Improve decoding (#291)

Browse files

* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ratio_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there is no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders

.gitignore CHANGED
@@ -8,6 +8,7 @@ build/
8
  build-em/
9
  build-debug/
10
  build-release/
 
11
  build-sanitize-addr/
12
  build-sanitize-thread/
13
 
@@ -18,6 +19,7 @@ build-sanitize-thread/
18
  /bench
19
 
20
  sync.sh
 
21
  libwhisper.so
22
  compile_commands.json
23
 
 
8
  build-em/
9
  build-debug/
10
  build-release/
11
+ build-static/
12
  build-sanitize-addr/
13
  build-sanitize-thread/
14
 
 
19
  /bench
20
 
21
  sync.sh
22
+ libwhisper.a
23
  libwhisper.so
24
  compile_commands.json
25
 
README.md CHANGED
@@ -212,17 +212,7 @@ make large
212
  ## Limitations
213
 
214
  - Inference only
215
- - No GPU support
216
- - Very basic greedy sampling scheme - always pick up the token with highest probability.
217
- This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
218
- from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
219
- to run the python code with the following parameters:
220
-
221
- ```
222
- whisper --best_of None --beam_size None ...
223
- ```
224
-
225
- In the future, `whisper.cpp` will support more sampling strategies.
226
 
227
  ## Another example
228
 
 
212
  ## Limitations
213
 
214
  - Inference only
215
+ - No GPU support (yet)
 
 
 
 
 
 
 
 
 
 
216
 
217
  ## Another example
218
 
examples/command/command.cpp CHANGED
@@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
671
  break;
672
  }
673
 
674
- const auto * probs = whisper_get_probs(ctx);
675
- std::vector<std::pair<float, int>> probs_id;
676
-
677
- double psum = 0.0;
678
- for (int i = 0; i < (int) allowed_commands.size(); ++i) {
679
- probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
680
- for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
681
- probs_id.back().first += probs[allowed_tokens[i][j]];
682
- }
683
- probs_id.back().first /= allowed_tokens[i].size();
684
- psum += probs_id.back().first;
685
- }
686
 
687
- // normalize
688
- for (auto & p : probs_id) {
689
- p.first /= psum;
690
- }
691
 
692
- // sort descending
693
- {
694
- using pair_type = decltype(probs_id)::value_type;
695
- std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
696
- return a.first > b.first;
697
- });
698
- }
699
 
700
- // print the commands and the respective probabilities
701
- {
702
- fprintf(stdout, "\n");
703
- for (const auto & cmd : probs_id) {
704
- fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
705
- for (int token : allowed_tokens[cmd.second]) {
706
- fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
 
707
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
  fprintf(stdout, "\n");
 
 
 
 
 
 
 
709
  }
710
- }
711
 
712
- // best command
713
- {
714
- const auto t_end = std::chrono::high_resolution_clock::now();
715
 
716
- const float prob = probs_id[0].first;
717
- const int index = probs_id[0].second;
718
 
719
- fprintf(stdout, "\n");
720
- fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
721
- "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
722
- (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
723
- fprintf(stdout, "\n");
 
724
  }
725
 
726
  audio.clear();
 
671
  break;
672
  }
673
 
674
+ // estimate command probability
675
+ // NOTE: not optimal
676
+ {
677
+ const auto * logits = whisper_get_logits(ctx);
 
 
 
 
 
 
 
 
678
 
679
+ std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
 
 
 
680
 
681
+ // compute probs from logits via softmax
682
+ {
683
+ float max = -1e9;
684
+ for (int i = 0; i < (int) probs.size(); ++i) {
685
+ max = std::max(max, logits[i]);
686
+ }
 
687
 
688
+ float sum = 0.0f;
689
+ for (int i = 0; i < (int) probs.size(); ++i) {
690
+ probs[i] = expf(logits[i] - max);
691
+ sum += probs[i];
692
+ }
693
+
694
+ for (int i = 0; i < (int) probs.size(); ++i) {
695
+ probs[i] /= sum;
696
  }
697
+ }
698
+
699
+ std::vector<std::pair<float, int>> probs_id;
700
+
701
+ double psum = 0.0;
702
+ for (int i = 0; i < (int) allowed_commands.size(); ++i) {
703
+ probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
704
+ for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
705
+ probs_id.back().first += probs[allowed_tokens[i][j]];
706
+ }
707
+ probs_id.back().first /= allowed_tokens[i].size();
708
+ psum += probs_id.back().first;
709
+ }
710
+
711
+ // normalize
712
+ for (auto & p : probs_id) {
713
+ p.first /= psum;
714
+ }
715
+
716
+ // sort descending
717
+ {
718
+ using pair_type = decltype(probs_id)::value_type;
719
+ std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
720
+ return a.first > b.first;
721
+ });
722
+ }
723
+
724
+ // print the commands and the respective probabilities
725
+ {
726
  fprintf(stdout, "\n");
727
+ for (const auto & cmd : probs_id) {
728
+ fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
729
+ for (int token : allowed_tokens[cmd.second]) {
730
+ fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
731
+ }
732
+ fprintf(stdout, "\n");
733
+ }
734
  }
 
735
 
736
+ // best command
737
+ {
738
+ const auto t_end = std::chrono::high_resolution_clock::now();
739
 
740
+ const float prob = probs_id[0].first;
741
+ const int index = probs_id[0].second;
742
 
743
+ fprintf(stdout, "\n");
744
+ fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
745
+ "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
746
+ (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
747
+ fprintf(stdout, "\n");
748
+ }
749
  }
750
 
751
  audio.clear();
examples/main/main.cpp CHANGED
@@ -59,8 +59,12 @@ struct whisper_params {
59
  int32_t duration_ms = 0;
60
  int32_t max_context = -1;
61
  int32_t max_len = 0;
 
 
62
 
63
- float word_thold = 0.01f;
 
 
64
 
65
  bool speed_up = false;
66
  bool translate = false;
@@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
104
  else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
105
  else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
106
  else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
 
 
107
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
 
 
108
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
109
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
110
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
136
  fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
137
  fprintf(stderr, "\n");
138
  fprintf(stderr, "options:\n");
139
- fprintf(stderr, " -h, --help [default] show this help message and exit\n");
140
- fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
141
- fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
142
- fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
143
- fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
144
- fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
145
- fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
146
- fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
147
- fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
148
- fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
149
- fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
150
- fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
151
- fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
152
- fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
153
- fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
154
- fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
155
- fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
156
- fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
157
- fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
158
- fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
159
- fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
160
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
161
- fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
162
- fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
163
- fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
 
 
 
 
164
  fprintf(stderr, "\n");
165
  }
166
 
@@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
235
  const char * text = whisper_full_get_token_text(ctx, i, j);
236
  const float p = whisper_full_get_token_p (ctx, i, j);
237
 
238
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
239
 
240
  printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
241
  }
@@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
331
  const int n_segments = whisper_full_n_segments(ctx);
332
  for (int i = 0; i < n_segments; ++i) {
333
  const char * text = whisper_full_get_segment_text(ctx, i);
334
- if (text[0] == ' ')
335
- text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
 
336
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
337
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
338
- //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
339
- fout << 10 * t0 << ", "
340
- << 10 * t1 << ", \""
341
- << text << "\"\n";
342
  }
343
 
344
  return true;
345
  }
346
 
347
-
348
  // karaoke video generation
349
  // outputs a bash script that uses ffmpeg to generate a video with the subtitles
350
  // TODO: font parameter adjustments
@@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
620
  {
621
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
622
 
 
 
623
  wparams.print_realtime = false;
624
  wparams.print_progress = params.print_progress;
625
  wparams.print_timestamps = !params.no_timestamps;
@@ -633,12 +646,18 @@ int main(int argc, char ** argv) {
633
 
634
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
635
  wparams.thold_pt = params.word_thold;
 
 
636
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
637
 
638
  wparams.speed_up = params.speed_up;
639
 
640
- wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
641
- wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
 
 
 
 
642
 
643
  whisper_print_user_data user_data = { &params, &pcmf32s };
644
 
 
59
  int32_t duration_ms = 0;
60
  int32_t max_context = -1;
61
  int32_t max_len = 0;
62
+ int32_t best_of = 5;
63
+ int32_t beam_size = -1;
64
 
65
+ float word_thold = 0.01f;
66
+ float entropy_thold = 2.4f;
67
+ float logprob_thold = -1.0f;
68
 
69
  bool speed_up = false;
70
  bool translate = false;
 
108
  else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
109
  else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
110
  else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
111
+ else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
112
+ else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
113
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
114
+ else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
115
+ else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
116
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
117
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
118
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
 
144
  fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
145
  fprintf(stderr, "\n");
146
  fprintf(stderr, "options:\n");
147
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
148
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
149
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
150
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
151
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
152
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
153
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
154
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
155
+ fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
156
+ fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
157
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
158
+ fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
159
+ fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
160
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
161
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
162
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
163
+ fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
164
+ fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
165
+ fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
166
+ fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
167
+ fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
168
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
169
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
170
+ fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
171
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
172
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
173
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
174
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
175
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
176
  fprintf(stderr, "\n");
177
  }
178
 
 
247
  const char * text = whisper_full_get_token_text(ctx, i, j);
248
  const float p = whisper_full_get_token_p (ctx, i, j);
249
 
250
+ const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
251
 
252
  printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
253
  }
 
343
  const int n_segments = whisper_full_n_segments(ctx);
344
  for (int i = 0; i < n_segments; ++i) {
345
  const char * text = whisper_full_get_segment_text(ctx, i);
346
+ if (text[0] == ' ') {
347
+ text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
348
+ }
349
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
350
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
351
+
352
+ //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
353
+ fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
 
354
  }
355
 
356
  return true;
357
  }
358
 
 
359
  // karaoke video generation
360
  // outputs a bash script that uses ffmpeg to generate a video with the subtitles
361
  // TODO: font parameter adjustments
 
631
  {
632
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
633
 
634
+ wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
635
+
636
  wparams.print_realtime = false;
637
  wparams.print_progress = params.print_progress;
638
  wparams.print_timestamps = !params.no_timestamps;
 
646
 
647
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
648
  wparams.thold_pt = params.word_thold;
649
+ wparams.entropy_thold = params.entropy_thold;
650
+ wparams.logprob_thold = params.logprob_thold;
651
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
652
 
653
  wparams.speed_up = params.speed_up;
654
 
655
+ wparams.greedy.best_of = params.best_of;
656
+ wparams.beam_search.beam_size = params.beam_size;
657
+ wparams.temperature_inc = -1;
658
+
659
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
660
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
661
 
662
  whisper_print_user_data user_data = { &params, &pcmf32s };
663
 
examples/stream.wasm/emscripten.cpp CHANGED
@@ -49,6 +49,9 @@ void stream_main(size_t index) {
49
  wparams.max_tokens = 32;
50
  wparams.audio_ctx = 768; // partial encoder context for better performance
51
 
 
 
 
52
  wparams.language = "en";
53
 
54
  printf("stream: using %d threads\n", wparams.n_threads);
 
49
  wparams.max_tokens = 32;
50
  wparams.audio_ctx = 768; // partial encoder context for better performance
51
 
52
+ // disable temperature fallback
53
+ wparams.temperature_inc = -1.0f;
54
+
55
  wparams.language = "en";
56
 
57
  printf("stream: using %d threads\n", wparams.n_threads);
examples/stream/stream.cpp CHANGED
@@ -615,6 +615,9 @@ int main(int argc, char ** argv) {
615
  wparams.audio_ctx = params.audio_ctx;
616
  wparams.speed_up = params.speed_up;
617
 
 
 
 
618
  wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
619
  wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
620
 
 
615
  wparams.audio_ctx = params.audio_ctx;
616
  wparams.speed_up = params.speed_up;
617
 
618
+ // disable temperature fallback
619
+ wparams.temperature_inc = -1.0f;
620
+
621
  wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
622
  wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
623
 
examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore ADDED
File without changes
examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore ADDED
File without changes
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj CHANGED
@@ -35,10 +35,10 @@
35
  0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
36
  0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
37
  0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
38
- 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = "<group>"; };
39
- 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = "<group>"; };
40
- 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
41
- 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
42
  0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
43
  0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
44
  /* End PBXFileReference section */
@@ -129,7 +129,8 @@
129
  0AAC5DC729539EB0003032C3 /* whisper.cpp */,
130
  0AAC5DC829539EB0003032C3 /* whisper.h */,
131
  );
132
- path = whisper.cpp;
 
133
  sourceTree = "<group>";
134
  };
135
  0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
 
35
  0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
36
  0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
37
  0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
38
+ 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
39
+ 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
40
+ 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
41
+ 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
42
  0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
43
  0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
44
  /* End PBXFileReference section */
 
129
  0AAC5DC729539EB0003032C3 /* whisper.cpp */,
130
  0AAC5DC829539EB0003032C3 /* whisper.h */,
131
  );
132
+ name = whisper.cpp;
133
+ path = ../..;
134
  sourceTree = "<group>";
135
  };
136
  0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
whisper.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
whisper.h CHANGED
@@ -74,6 +74,7 @@ extern "C" {
74
  whisper_token tid; // forced timestamp token id
75
 
76
  float p; // probability of the token
 
77
  float pt; // probability of the timestamp token
78
  float ptsum; // sum of probabilities of all timestamp tokens
79
 
@@ -136,6 +137,7 @@ extern "C" {
136
  // tokens + n_tokens is the provided context for the decoder.
137
  // n_past is the number of tokens to use from previous decoder calls.
138
  // Returns 0 on success
 
139
  WHISPER_API int whisper_decode(
140
  struct whisper_context * ctx,
141
  const whisper_token * tokens,
@@ -143,14 +145,6 @@ extern "C" {
143
  int n_past,
144
  int n_threads);
145
 
146
- // Token sampling methods.
147
- // These are provided for convenience and can be used after each call to whisper_decode().
148
- // You can also implement your own sampling method using the whisper_get_probs() function.
149
- // whisper_sample_best() returns the token with the highest probability
150
- // whisper_sample_timestamp() returns the most probable timestamp token
151
- WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
152
- WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
153
-
154
  // Convert the provided text into tokens.
155
  // The tokens pointer must be large enough to hold the resulting tokens.
156
  // Returns the number of tokens on success, no more than n_max_tokens
@@ -192,8 +186,11 @@ extern "C" {
192
  WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
193
  WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
194
 
195
- // The probabilities for the next token
196
- WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
 
 
 
197
 
198
  // Token Id -> String. Uses the vocabulary in the provided context
199
  WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@@ -222,8 +219,8 @@ extern "C" {
222
 
223
  // Available sampling strategies
224
  enum whisper_sampling_strategy {
225
- WHISPER_SAMPLING_GREEDY, // Always select the most probable token
226
- WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
227
  };
228
 
229
  // Text segment callback
@@ -243,17 +240,17 @@ extern "C" {
243
  enum whisper_sampling_strategy strategy;
244
 
245
  int n_threads;
246
- int n_max_text_ctx;
247
  int offset_ms; // start offset in ms
248
  int duration_ms; // audio duration to process in ms
249
 
250
  bool translate;
251
- bool no_context;
252
  bool single_segment; // force single segment output (useful for streaming)
253
- bool print_special;
254
- bool print_progress;
255
- bool print_realtime;
256
- bool print_timestamps;
257
 
258
  // [EXPERIMENTAL] token-level timestamps
259
  bool token_timestamps; // enable token-level timestamps
@@ -263,10 +260,11 @@ extern "C" {
263
  int max_tokens; // max tokens per segment (0 = no limit)
264
 
265
  // [EXPERIMENTAL] speed-up techniques
 
266
  bool speed_up; // speed-up the audio by 2x using Phase Vocoder
267
  int audio_ctx; // overwrite the audio context size (0 = use default)
268
 
269
- // tokens to provide the whisper model as initial prompt
270
  // these are prepended to any existing text context from a previous call
271
  const whisper_token * prompt_tokens;
272
  int prompt_n_tokens;
@@ -274,19 +272,35 @@ extern "C" {
274
  // for auto-detection, set to nullptr, "" or "auto"
275
  const char * language;
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  struct {
278
- int n_past;
279
  } greedy;
280
 
281
  struct {
282
- int n_past;
283
- int beam_width;
284
- int n_best;
285
  } beam_search;
286
 
 
287
  whisper_new_segment_callback new_segment_callback;
288
  void * new_segment_callback_user_data;
289
 
 
290
  whisper_encoder_begin_callback encoder_begin_callback;
291
  void * encoder_begin_callback_user_data;
292
  };
 
74
  whisper_token tid; // forced timestamp token id
75
 
76
  float p; // probability of the token
77
+ float plog; // log probability of the token
78
  float pt; // probability of the timestamp token
79
  float ptsum; // sum of probabilities of all timestamp tokens
80
 
 
137
  // tokens + n_tokens is the provided context for the decoder.
138
  // n_past is the number of tokens to use from previous decoder calls.
139
  // Returns 0 on success
140
+ // TODO: add support for multiple decoders
141
  WHISPER_API int whisper_decode(
142
  struct whisper_context * ctx,
143
  const whisper_token * tokens,
 
145
  int n_past,
146
  int n_threads);
147
 
 
 
 
 
 
 
 
 
148
  // Convert the provided text into tokens.
149
  // The tokens pointer must be large enough to hold the resulting tokens.
150
  // Returns the number of tokens on success, no more than n_max_tokens
 
186
  WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
187
  WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
188
 
189
+ // Token logits obtained from the last call to whisper_decode()
190
+ // The logits for the last token are stored in the last row
191
+ // Rows: n_tokens
192
+ // Cols: n_vocab
193
+ WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
194
 
195
  // Token Id -> String. Uses the vocabulary in the provided context
196
  WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
 
219
 
220
  // Available sampling strategies
221
  enum whisper_sampling_strategy {
222
+ WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder
223
+ WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
224
  };
225
 
226
  // Text segment callback
 
240
  enum whisper_sampling_strategy strategy;
241
 
242
  int n_threads;
243
+ int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
244
  int offset_ms; // start offset in ms
245
  int duration_ms; // audio duration to process in ms
246
 
247
  bool translate;
248
+ bool no_context; // do not use initial prompt for the decoder (if any)
249
  bool single_segment; // force single segment output (useful for streaming)
250
+ bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
251
+ bool print_progress; // print progress information
252
+ bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
253
+ bool print_timestamps; // print timestamps for each text segment when printing realtime
254
 
255
  // [EXPERIMENTAL] token-level timestamps
256
  bool token_timestamps; // enable token-level timestamps
 
260
  int max_tokens; // max tokens per segment (0 = no limit)
261
 
262
  // [EXPERIMENTAL] speed-up techniques
263
+ // note: these can significantly reduce the quality of the output
264
  bool speed_up; // speed-up the audio by 2x using Phase Vocoder
265
  int audio_ctx; // overwrite the audio context size (0 = use default)
266
 
267
+ // tokens to provide to the whisper decoder as initial prompt
268
  // these are prepended to any existing text context from a previous call
269
  const whisper_token * prompt_tokens;
270
  int prompt_n_tokens;
 
272
  // for auto-detection, set to nullptr, "" or "auto"
273
  const char * language;
274
 
275
+ // common decoding parameters:
276
+ bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
277
+
278
+ float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
279
+ float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
280
+ float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
281
+
282
+ // fallback parameters
283
+ // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
284
+ float temperature_inc;
285
+ float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
286
+ float logprob_thold;
287
+ float no_speech_thold; // TODO: not implemented
288
+
289
  struct {
290
+ int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
291
  } greedy;
292
 
293
  struct {
294
+ int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
295
+
296
+ float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
297
  } beam_search;
298
 
299
+ // called for every newly generated text segment
300
  whisper_new_segment_callback new_segment_callback;
301
  void * new_segment_callback_user_data;
302
 
303
+ // called each time before the encoder starts
304
  whisper_encoder_begin_callback encoder_begin_callback;
305
  void * encoder_begin_callback_user_data;
306
  };