ggerganov committed
Commit a570c92 · unverified · Parent: 46033e6

whisper : add support for new distilled Whisper models (#1424)


* whisper : add support for new distilled Whisper models

* whisper : print log when using distilled models

Files changed (1)
  whisper.cpp  +13  -0
whisper.cpp CHANGED

@@ -3940,6 +3940,7 @@ static void whisper_process_logits(
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev]       = -INFINITY;

         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4558,6 +4559,7 @@ int whisper_full_with_state(

     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
@@ -4569,6 +4571,17 @@ int whisper_full_with_state(
         }
     }

+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+
+        // distilled models require the "no_timestamps" token
+        // TODO: add input parameter (#1229)
+        if (is_distil) {
+            log("%s: using distilled model - forcing no_timestamps\n", __func__);
+            prompt_init.push_back(whisper_token_not(ctx));
+        }
+    }
+
     int seek = seek_start;

     std::vector<whisper_token> prompt;
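
For context, a minimal usage sketch (not part of this commit) of how a converted distilled model could be run through the public whisper.h API. The model path "models/ggml-distil-large-v2.bin" and the audio buffer are placeholders; real code must supply 16 kHz mono float PCM. With this change, whisper.cpp detects distilled checkpoints by their two-layer text decoder (n_text_layer == 2), appends the no_timestamps token to the initial prompt, and prints the log line shown in the diff, so per-segment timestamps should not be expected.

// Minimal sketch, not part of this commit: running a distilled model through
// the public whisper.h API. The model path and the empty pcmf32 buffer are
// placeholders - the caller must provide a real converted model and
// 16 kHz mono float samples.
#include "whisper.h"

#include <cstdio>
#include <vector>

int main() {
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-distil-large-v2.bin");
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    std::vector<float> pcmf32; // 16 kHz mono PCM, filled by the caller

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // for a distilled model, the change above forces the no_timestamps token
    // internally and logs "using distilled model - forcing no_timestamps"
    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        fprintf(stderr, "whisper_full failed\n");
        whisper_free(ctx);
        return 1;
    }

    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }

    whisper_free(ctx);
    return 0;
}

The n_text_layer == 2 heuristic works because the distil-whisper checkpoints keep the full encoder but reduce the text decoder to two layers, which no stock Whisper model uses; the TODO notes that an explicit input parameter (#1229) is the intended longer-term replacement for this heuristic.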