sachaarbonel commited on
Commit
8e40db9
·
unverified ·
1 Parent(s): 5b0631d

server : add no-speech threshold parameter and functionality (#2654)

Browse files
examples/server/server.cpp CHANGED
@@ -61,6 +61,7 @@ struct whisper_params {
61
  float logprob_thold = -1.00f;
62
  float temperature = 0.00f;
63
  float temperature_inc = 0.20f;
 
64
 
65
  bool debug_mode = false;
66
  bool translate = false;
@@ -137,6 +138,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
137
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
138
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
139
  fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
 
140
  fprintf(stderr, "\n");
141
  }
142
 
@@ -182,6 +184,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
182
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
183
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
184
  else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
 
 
185
  // server params
186
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
187
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -790,6 +794,7 @@ int main(int argc, char ** argv) {
790
  wparams.beam_search.beam_size = params.beam_size;
791
 
792
  wparams.temperature = params.temperature;
 
793
  wparams.temperature_inc = params.temperature_inc;
794
  wparams.entropy_thold = params.entropy_thold;
795
  wparams.logprob_thold = params.logprob_thold;
@@ -942,7 +947,7 @@ int main(int argc, char ** argv) {
942
 
943
  // TODO compression_ratio and no_speech_prob are not implemented yet
944
  // segment["compression_ratio"] = 0;
945
- // segment["no_speech_prob"] = 0;
946
 
947
  jres["segments"].push_back(segment);
948
  }
 
61
  float logprob_thold = -1.00f;
62
  float temperature = 0.00f;
63
  float temperature_inc = 0.20f;
64
+ float no_speech_thold = 0.6f;
65
 
66
  bool debug_mode = false;
67
  bool translate = false;
 
138
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
139
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
140
  fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
141
+ fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
142
  fprintf(stderr, "\n");
143
  }
144
 
 
184
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
185
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
186
  else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
187
+ else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
188
+
189
  // server params
190
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
191
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
 
794
  wparams.beam_search.beam_size = params.beam_size;
795
 
796
  wparams.temperature = params.temperature;
797
+ wparams.no_speech_thold = params.no_speech_thold;
798
  wparams.temperature_inc = params.temperature_inc;
799
  wparams.entropy_thold = params.entropy_thold;
800
  wparams.logprob_thold = params.logprob_thold;
 
947
 
948
  // TODO compression_ratio and no_speech_prob are not implemented yet
949
  // segment["compression_ratio"] = 0;
950
+ segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
951
 
952
  jres["segments"].push_back(segment);
953
  }
include/whisper.h CHANGED
@@ -665,6 +665,8 @@ extern "C" {
665
 
666
  WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
667
 
 
 
668
  #ifdef __cplusplus
669
  }
670
  #endif
 
665
 
666
  WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
667
 
668
+ // Get the no_speech probability for the specified segment
669
+ WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment);
670
  #ifdef __cplusplus
671
  }
672
  #endif
src/whisper.cpp CHANGED
@@ -428,6 +428,7 @@ struct whisper_segment {
428
  int64_t t1;
429
 
430
  std::string text;
 
431
 
432
  std::vector<whisper_token_data> tokens;
433
 
@@ -6147,7 +6148,7 @@ int whisper_full_with_state(
6147
 
6148
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
6149
 
6150
- result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
6151
  for (int j = i0; j <= i; j++) {
6152
  result_all.back().tokens.push_back(tokens_cur[j]);
6153
  }
@@ -6192,7 +6193,7 @@ int whisper_full_with_state(
6192
  }
6193
  }
6194
 
6195
- result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
6196
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
6197
  result_all.back().tokens.push_back(tokens_cur[j]);
6198
  }
@@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
6459
  return ctx->state->result_all[i_segment].tokens[i_token].p;
6460
  }
6461
 
 
 
 
 
6462
  // =================================================================================================
6463
 
6464
  //
 
428
  int64_t t1;
429
 
430
  std::string text;
431
+ float no_speech_prob;
432
 
433
  std::vector<whisper_token_data> tokens;
434
 
 
6148
 
6149
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
6150
 
6151
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6152
  for (int j = i0; j <= i; j++) {
6153
  result_all.back().tokens.push_back(tokens_cur[j]);
6154
  }
 
6193
  }
6194
  }
6195
 
6196
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6197
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
6198
  result_all.back().tokens.push_back(tokens_cur[j]);
6199
  }
 
6460
  return ctx->state->result_all[i_segment].tokens[i_token].p;
6461
  }
6462
 
6463
+ float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
6464
+ return ctx->state->result_all[i_segment].no_speech_prob;
6465
+ }
6466
+
6467
  // =================================================================================================
6468
 
6469
  //