jhenhong commited on
Commit
81dab6f
·
unverified ·
1 Parent(s): 4fc344e

whisper : add initial_prompt param (#645)

Browse files
examples/addon.node/addon.cpp CHANGED
@@ -160,22 +160,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
160
  return 3;
161
  }
162
 
163
- // initial prompt
164
- std::vector<whisper_token> prompt_tokens;
165
-
166
- if (!params.prompt.empty()) {
167
- prompt_tokens.resize(1024);
168
- prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
169
-
170
- fprintf(stderr, "\n");
171
- fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
172
- fprintf(stderr, "initial tokens: [ ");
173
- for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
174
- fprintf(stderr, "%d ", prompt_tokens[i]);
175
- }
176
- fprintf(stderr, "]\n");
177
- }
178
-
179
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
180
  const auto fname_inp = params.fname_inp[f];
181
  const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -243,8 +227,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
243
  wparams.greedy.best_of = params.best_of;
244
  wparams.beam_search.beam_size = params.beam_size;
245
 
246
- wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
247
- wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
248
 
249
  whisper_print_user_data user_data = { &params, &pcmf32s };
250
 
 
160
  return 3;
161
  }
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
164
  const auto fname_inp = params.fname_inp[f];
165
  const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
 
227
  wparams.greedy.best_of = params.best_of;
228
  wparams.beam_search.beam_size = params.beam_size;
229
 
230
+ wparams.initial_prompt = params.prompt.c_str();
 
231
 
232
  whisper_print_user_data user_data = { &params, &pcmf32s };
233
 
examples/main/main.cpp CHANGED
@@ -639,22 +639,6 @@ int main(int argc, char ** argv) {
639
  return 3;
640
  }
641
 
642
- // initial prompt
643
- std::vector<whisper_token> prompt_tokens;
644
-
645
- if (!params.prompt.empty()) {
646
- prompt_tokens.resize(1024);
647
- prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
648
-
649
- fprintf(stderr, "\n");
650
- fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
651
- fprintf(stderr, "initial tokens: [ ");
652
- for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
653
- fprintf(stderr, "%d ", prompt_tokens[i]);
654
- }
655
- fprintf(stderr, "]\n");
656
- }
657
-
658
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
659
  const auto fname_inp = params.fname_inp[f];
660
  const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -718,8 +702,7 @@ int main(int argc, char ** argv) {
718
 
719
  wparams.speed_up = params.speed_up;
720
 
721
- wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
722
- wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
723
 
724
  wparams.greedy.best_of = params.best_of;
725
  wparams.beam_search.beam_size = params.beam_size;
 
639
  return 3;
640
  }
641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
643
  const auto fname_inp = params.fname_inp[f];
644
  const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
 
702
 
703
  wparams.speed_up = params.speed_up;
704
 
705
+ wparams.initial_prompt = params.prompt.c_str();
 
706
 
707
  wparams.greedy.best_of = params.best_of;
708
  wparams.beam_search.beam_size = params.beam_size;
whisper.cpp CHANGED
@@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3121
  /*.speed_up =*/ false,
3122
  /*.audio_ctx =*/ 0,
3123
 
 
3124
  /*.prompt_tokens =*/ nullptr,
3125
  /*.prompt_n_tokens =*/ 0,
3126
 
@@ -3793,6 +3794,15 @@ int whisper_full_with_state(
3793
  prompt_past.clear();
3794
  }
3795
 
 
 
 
 
 
 
 
 
 
3796
  // prepend the prompt tokens to the prompt_past
3797
  if (params.prompt_tokens && params.prompt_n_tokens > 0) {
3798
  // parse tokens from the pointer
 
3121
  /*.speed_up =*/ false,
3122
  /*.audio_ctx =*/ 0,
3123
 
3124
+ /*.initial_prompt =*/ nullptr,
3125
  /*.prompt_tokens =*/ nullptr,
3126
  /*.prompt_n_tokens =*/ 0,
3127
 
 
3794
  prompt_past.clear();
3795
  }
3796
 
3797
+ // initial prompt
3798
+ if (!params.prompt_tokens && params.initial_prompt) {
3799
+ std::vector<whisper_token> prompt_tokens;
3800
+ prompt_tokens.resize(1024);
3801
+ prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
3802
+ params.prompt_tokens = prompt_tokens.data();
3803
+ params.prompt_n_tokens = prompt_tokens.size();
3804
+ }
3805
+
3806
  // prepend the prompt tokens to the prompt_past
3807
  if (params.prompt_tokens && params.prompt_n_tokens > 0) {
3808
  // parse tokens from the pointer
whisper.h CHANGED
@@ -356,6 +356,7 @@ extern "C" {
356
 
357
  // tokens to provide to the whisper decoder as initial prompt
358
  // these are prepended to any existing text context from a previous call
 
359
  const whisper_token * prompt_tokens;
360
  int prompt_n_tokens;
361
 
 
356
 
357
  // tokens to provide to the whisper decoder as initial prompt
358
  // these are prepended to any existing text context from a previous call
359
+ const char * initial_prompt;
360
  const whisper_token * prompt_tokens;
361
  int prompt_n_tokens;
362