Commit d4425e8 by ggerganov · 1 Parent(s): 38f9f3b

whisper : fix bug in prompt processing (close #705)

Files changed (2):
  1. examples/main/main.cpp  +2 -2
  2. whisper.cpp             +23 -21
examples/main/main.cpp CHANGED

@@ -208,8 +208,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
 
     std::string speaker = "";
 
-    int64_t t0;
-    int64_t t1;
+    int64_t t0 = 0;
+    int64_t t1 = 0;
 
     // print the last n_new segments
     const int s0 = n_segments - n_new;
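In the print callback, t0 and t1 are only assigned on the paths that actually fetch segment timestamps, so any later unconditional use could read indeterminate values, which is undefined behavior in C++. Zero-initializing them makes every path well-defined. A minimal sketch of the pattern (hypothetical example, not the whisper.cpp callback itself):

#include <cstdint>
#include <cstdio>

// A local that is only conditionally assigned must be initialized before any
// unconditional use, otherwise the read is undefined behavior.
static void print_segment(bool with_timestamps) {
    int64_t t0 = 0; // without "= 0", the printf below could read an indeterminate value
    int64_t t1 = 0;

    if (with_timestamps) {
        t0 = 100; // in the real callback these come from whisper_full_get_segment_t0/t1
        t1 = 250;
    }

    // unconditional use: only safe because t0/t1 were initialized above
    std::printf("[%lld -> %lld]\n", (long long) t0, (long long) t1);
}

int main() {
    print_segment(false);
    print_segment(true);
}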
whisper.cpp CHANGED
@@ -1260,12 +1260,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 break;
             }
 
-            int64_t nelements = 1;
-            int64_t ne[3] = { 1, 1, 1 };
+            int32_t nelements = 1;
+            int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                int32_t ne_cur;
-                read_safe(loader, ne_cur);
-                ne[i] = ne_cur;
+                read_safe(loader, ne[i]);
                 nelements *= ne[i];
             }
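Each tensor dimension is stored in the model file as a 32-bit integer, so switching ne and nelements to int32_t lets read_safe deserialize straight into ne[i] and drops the ne_cur temporary. A hedged sketch of what such a helper looks like (assumed shape only; the real read_safe goes through the whisper_model_loader callbacks rather than a FILE*):

#include <cstdint>
#include <cstdio>

// Sketch of a generic binary reader in the spirit of whisper.cpp's read_safe.
template <typename T>
static bool read_safe(std::FILE * f, T & dest) {
    // read exactly sizeof(T) bytes into dest; works for any trivially copyable type,
    // which is why passing ne[i] directly is enough once the array element type
    // matches the 32-bit on-disk encoding of the dimensions
    return std::fread(&dest, sizeof(T), 1, f) == 1;
}

static bool read_dims(std::FILE * f, int n_dims, int32_t (&ne)[3], int32_t & nelements) {
    nelements = 1;
    for (int i = 0; i < n_dims; ++i) {
        if (!read_safe(f, ne[i])) {
            return false;
        }
        nelements *= ne[i];
    }
    return true;
}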
 
@@ -1286,15 +1284,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
 
             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lld, %lld, %lld], expected [%lld, %lld, %lld]\n",
-                        __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
 
             const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
 
             if (nelements*bpe != ggml_nbytes(tensor)) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %llu\n",
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                return false;
            }
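With ne now int32_t, the old %lld specifiers no longer match their arguments, and %lld/%llu were not a portable match for the 64-bit ggml dimensions or for size_t in the first place. The new code casts the ggml dimensions to int and prints them with %d, and uses %zu for the byte counts (nelements*bpe is int32_t multiplied by size_t, which promotes to size_t). A small illustrative sketch of the specifier-matching rule, assuming nothing beyond the standard library:

#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
    int64_t big   = 123;                    // e.g. a ggml tensor dimension
    int32_t small = 3;
    size_t  bytes = sizeof(float) * 123;

    // Option used in the commit: cast to a known type and use its specifier.
    std::printf("dim = %d\n", (int) big);

    // Alternative for the full 64-bit range: the <cinttypes> macros.
    std::printf("dim = %" PRId64 "\n", big);

    // size_t (and int32_t * size_t, which promotes to size_t) pairs with %zu.
    std::printf("size = %zu\n", small * bytes);
}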
@@ -3819,22 +3817,26 @@ int whisper_full_with_state(
         prompt_past.clear();
     }
 
-    // initial prompt
-    if (!params.prompt_tokens && params.initial_prompt) {
+    // prepare prompt
+    {
         std::vector<whisper_token> prompt_tokens;
-        prompt_tokens.resize(1024);
-        prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
-        params.prompt_tokens = prompt_tokens.data();
-        params.prompt_n_tokens = prompt_tokens.size();
-    }
 
-    // prepend the prompt tokens to the prompt_past
-    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
-        // parse tokens from the pointer
-        for (int i = 0; i < params.prompt_n_tokens; i++) {
-            prompt_past.push_back(params.prompt_tokens[i]);
+        // initial prompt
+        if (!params.prompt_tokens && params.initial_prompt) {
+            prompt_tokens.resize(1024);
+            prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
+            params.prompt_tokens = prompt_tokens.data();
+            params.prompt_n_tokens = prompt_tokens.size();
+        }
+
+        // prepend the prompt tokens to the prompt_past
+        if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+            // parse tokens from the pointer
+            for (int i = 0; i < params.prompt_n_tokens; i++) {
+                prompt_past.push_back(params.prompt_tokens[i]);
+            }
+            std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
         }
-        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
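This hunk is the actual fix for #705: in the old code, prompt_tokens was a vector local to the first if block, and params.prompt_tokens was pointed at its storage; by the time the second if block copied the tokens into prompt_past, the vector had already been destroyed, so the copy read through a dangling pointer. Wrapping both steps in a single "prepare prompt" scope keeps the vector alive until after the copy. A stripped-down sketch of the fixed pattern (hypothetical names and types, not the whisper.cpp API):

#include <cstdio>
#include <string>
#include <vector>

// A pointer stored in params must not outlive the vector it points into.
struct params_t {
    const char * initial_prompt  = "hello world";
    const int  * prompt_tokens   = nullptr;
    int          prompt_n_tokens = 0;
};

static std::vector<int> tokenize(const std::string & text) {
    // stand-in for whisper_tokenize: one fake token per character
    return std::vector<int>(text.begin(), text.end());
}

int main() {
    params_t params;
    std::vector<int> prompt_past;

    // Fixed pattern (mirrors the commit): one scope keeps prompt_tokens alive
    // until after its contents have been copied into prompt_past.
    {
        std::vector<int> prompt_tokens;

        if (!params.prompt_tokens && params.initial_prompt) {
            prompt_tokens = tokenize(params.initial_prompt);
            params.prompt_tokens   = prompt_tokens.data();
            params.prompt_n_tokens = (int) prompt_tokens.size();
        }

        if (params.prompt_tokens && params.prompt_n_tokens > 0) {
            for (int i = 0; i < params.prompt_n_tokens; i++) {
                prompt_past.push_back(params.prompt_tokens[i]); // safe: vector still alive here
            }
        }
    } // in the old code the vector was destroyed before this copy ran

    std::printf("prompt_past holds %zu tokens\n", prompt_past.size());
    return 0;
}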