Hrishikesh Barman committed on
Commit
cde9ea7
·
unverified ·
1 Parent(s): fde2cd7

whisper : move progress calculation out of whisper.cpp (#1081)

Browse files

Previously, `progress_step` was hardcoded in whisper.cpp. As a result,
bindings could only observe progress at that fixed step, even though the
progress callback was being called on every iteration.

With this change, whisper.cpp reports progress with finer granularity, and
bindings/implementations can define their own progress step.

Files changed (2) hide show
  1. examples/main/main.cpp +16 -1
  2. whisper.cpp +1 -10
examples/main/main.cpp CHANGED
@@ -59,6 +59,7 @@ struct whisper_params {
59
  int32_t offset_t_ms = 0;
60
  int32_t offset_n = 0;
61
  int32_t duration_ms = 0;
 
62
  int32_t max_context = -1;
63
  int32_t max_len = 0;
64
  int32_t best_of = 2;
@@ -218,6 +219,7 @@ struct whisper_print_user_data {
218
  const whisper_params * params;
219
 
220
  const std::vector<std::vector<float>> * pcmf32s;
 
221
  };
222
 
223
  std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
@@ -252,6 +254,14 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
252
 
253
  return speaker;
254
  }
 
 
 
 
 
 
 
 
255
 
256
  void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
257
  const auto & params = *((whisper_print_user_data *) user_data)->params;
@@ -895,7 +905,7 @@ int main(int argc, char ** argv) {
895
  wparams.entropy_thold = params.entropy_thold;
896
  wparams.logprob_thold = params.logprob_thold;
897
 
898
- whisper_print_user_data user_data = { &params, &pcmf32s };
899
 
900
  // this callback is called on each new segment
901
  if (!wparams.print_realtime) {
@@ -903,6 +913,11 @@ int main(int argc, char ** argv) {
903
  wparams.new_segment_callback_user_data = &user_data;
904
  }
905
 
 
 
 
 
 
906
  // example for abort mechanism
907
  // in this example, we do not abort the processing, but we could if the flag is set to true
908
  // the callback is called before every encoder run - if it returns false, the processing is aborted
 
59
  int32_t offset_t_ms = 0;
60
  int32_t offset_n = 0;
61
  int32_t duration_ms = 0;
62
+ int32_t progress_step = 5;
63
  int32_t max_context = -1;
64
  int32_t max_len = 0;
65
  int32_t best_of = 2;
 
219
  const whisper_params * params;
220
 
221
  const std::vector<std::vector<float>> * pcmf32s;
222
+ int progress_prev;
223
  };
224
 
225
  std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
 
254
 
255
  return speaker;
256
  }
257
+ void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
258
+ int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
259
+ int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
260
+ if (progress >= *progress_prev + progress_step) {
261
+ *progress_prev += progress_step;
262
+ fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
263
+ }
264
+ }
265
 
266
  void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
267
  const auto & params = *((whisper_print_user_data *) user_data)->params;
 
905
  wparams.entropy_thold = params.entropy_thold;
906
  wparams.logprob_thold = params.logprob_thold;
907
 
908
+ whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
909
 
910
  // this callback is called on each new segment
911
  if (!wparams.print_realtime) {
 
913
  wparams.new_segment_callback_user_data = &user_data;
914
  }
915
 
916
+ if (wparams.print_progress) {
917
+ wparams.progress_callback = whisper_print_progress_callback;
918
+ wparams.progress_callback_user_data = &user_data;
919
+ }
920
+
921
  // example for abort mechanism
922
  // in this example, we do not abort the processing, but we could if the flag is set to true
923
  // the callback is called before every encoder run - if it returns false, the processing is aborted
whisper.cpp CHANGED
@@ -4163,9 +4163,6 @@ int whisper_full_with_state(
4163
  }
4164
  }
4165
 
4166
- int progress_prev = 0;
4167
- int progress_step = 5;
4168
-
4169
  int seek = seek_start;
4170
 
4171
  std::vector<whisper_token> prompt;
@@ -4193,15 +4190,9 @@ int whisper_full_with_state(
4193
  // main loop
4194
  while (true) {
4195
  const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
4196
- while (progress_cur >= progress_prev + progress_step) {
4197
- progress_prev += progress_step;
4198
- if (params.print_progress) {
4199
- fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
4200
- }
4201
- }
4202
  if (params.progress_callback) {
4203
  params.progress_callback(
4204
- ctx, ctx->state, progress_prev, params.progress_callback_user_data);
4205
  }
4206
 
4207
  // of only 1 second left, then stop
 
4163
  }
4164
  }
4165
 
 
 
 
4166
  int seek = seek_start;
4167
 
4168
  std::vector<whisper_token> prompt;
 
4190
  // main loop
4191
  while (true) {
4192
  const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
 
 
 
 
 
 
4193
  if (params.progress_callback) {
4194
  params.progress_callback(
4195
+ ctx, ctx->state, progress_cur, params.progress_callback_user_data);
4196
  }
4197
 
4198
  // of only 1 second left, then stop