Colin ggerganov commited on
Commit
16846f8
·
unverified ·
1 Parent(s): 1b58b55

main : add diarization support for all current output types (#1031)

Browse files
Files changed (1) hide show
  1. examples/main/main.cpp +118 -50
examples/main/main.cpp CHANGED
@@ -210,6 +210,39 @@ struct whisper_print_user_data {
210
  const std::vector<std::vector<float>> * pcmf32s;
211
  };
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
214
  const auto & params = *((whisper_print_user_data *) user_data)->params;
215
  const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
239
  }
240
 
241
  if (params.diarize && pcmf32s.size() == 2) {
242
- const int64_t n_samples = pcmf32s[0].size();
243
-
244
- const int64_t is0 = timestamp_to_sample(t0, n_samples);
245
- const int64_t is1 = timestamp_to_sample(t1, n_samples);
246
-
247
- double energy0 = 0.0f;
248
- double energy1 = 0.0f;
249
-
250
- for (int64_t j = is0; j < is1; j++) {
251
- energy0 += fabs(pcmf32s[0][j]);
252
- energy1 += fabs(pcmf32s[1][j]);
253
- }
254
-
255
- if (energy0 > 1.1*energy1) {
256
- speaker = "(speaker 0)";
257
- } else if (energy1 > 1.1*energy0) {
258
- speaker = "(speaker 1)";
259
- } else {
260
- speaker = "(speaker ?)";
261
- }
262
-
263
- //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
264
  }
265
 
266
  if (params.print_colors) {
@@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
294
  }
295
  }
296
 
297
- bool output_txt(struct whisper_context * ctx, const char * fname) {
298
  std::ofstream fout(fname);
299
  if (!fout.is_open()) {
300
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
306
  const int n_segments = whisper_full_n_segments(ctx);
307
  for (int i = 0; i < n_segments; ++i) {
308
  const char * text = whisper_full_get_segment_text(ctx, i);
309
- fout << text << "\n";
 
 
 
 
 
 
 
 
 
310
  }
311
 
312
  return true;
313
  }
314
 
315
- bool output_vtt(struct whisper_context * ctx, const char * fname) {
316
  std::ofstream fout(fname);
317
  if (!fout.is_open()) {
318
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
328
  const char * text = whisper_full_get_segment_text(ctx, i);
329
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
330
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
 
 
 
 
 
 
 
331
 
332
  fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
333
- fout << text << "\n\n";
334
  }
335
 
336
  return true;
337
  }
338
 
339
- bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
340
  std::ofstream fout(fname);
341
  if (!fout.is_open()) {
342
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
350
  const char * text = whisper_full_get_segment_text(ctx, i);
351
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
352
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
 
 
 
 
 
353
 
354
  fout << i + 1 + params.offset_n << "\n";
355
  fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
356
- fout << text << "\n\n";
357
  }
358
 
359
  return true;
@@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
390
  return escaped;
391
  }
392
 
393
- bool output_csv(struct whisper_context * ctx, const char * fname) {
394
  std::ofstream fout(fname);
395
  if (!fout.is_open()) {
396
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -400,7 +435,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
400
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
401
 
402
  const int n_segments = whisper_full_n_segments(ctx);
403
- fout << "start,end,text\n";
 
 
 
 
 
 
404
  for (int i = 0; i < n_segments; ++i) {
405
  const char * text = whisper_full_get_segment_text(ctx, i);
406
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
@@ -408,13 +449,18 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
408
  char * text_escaped = escape_double_quotes_and_backslashes(text);
409
 
410
  //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
411
- fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
 
 
 
 
 
412
  }
413
 
414
  return true;
415
  }
416
 
417
- bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
418
  std::ofstream fout(fname);
419
  int indent = 0;
420
 
@@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
530
  value_i("from", t0 * 10, false);
531
  value_i("to", t1 * 10, true);
532
  end_obj(false);
533
- value_s("text", text, true);
 
 
 
 
534
  end_obj(i == (n_segments - 1));
535
  }
536
 
@@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
542
  // karaoke video generation
543
  // outputs a bash script that uses ffmpeg to generate a video with the subtitles
544
  // TODO: font parameter adjustments
545
- bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
546
  std::ofstream fout(fname);
547
 
548
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
579
  fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
580
 
581
  bool is_first = true;
 
 
 
 
 
582
 
583
  for (int j = 0; j < n; ++j) {
584
  const auto & token = tokens[j];
@@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
587
  continue;
588
  }
589
 
590
- std::string txt_bg;
591
- std::string txt_fg; // highlight token
592
- std::string txt_ul; // underline
593
 
594
- txt_bg = "> ";
595
- txt_fg = "> ";
596
- txt_ul = "\\ \\ ";
 
 
 
 
 
 
597
 
598
  {
599
  for (int k = 0; k < n; ++k) {
@@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
656
  return true;
657
  }
658
 
659
- bool output_lrc(struct whisper_context * ctx, const char * fname) {
660
-
661
  std::ofstream fout(fname);
662
  if (!fout.is_open()) {
663
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
682
  char buf[16];
683
  snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
684
  std::string timestamp_lrc = std::string(buf);
 
 
 
 
 
 
 
 
685
 
686
- fout << '[' << timestamp_lrc << ']' << text << "\n";
687
  }
688
 
689
  return true;
@@ -828,43 +896,43 @@ int main(int argc, char ** argv) {
828
  // output to text file
829
  if (params.output_txt) {
830
  const auto fname_txt = fname_out + ".txt";
831
- output_txt(ctx, fname_txt.c_str());
832
  }
833
 
834
  // output to VTT file
835
  if (params.output_vtt) {
836
  const auto fname_vtt = fname_out + ".vtt";
837
- output_vtt(ctx, fname_vtt.c_str());
838
  }
839
 
840
  // output to SRT file
841
  if (params.output_srt) {
842
  const auto fname_srt = fname_out + ".srt";
843
- output_srt(ctx, fname_srt.c_str(), params);
844
  }
845
 
846
  // output to WTS file
847
  if (params.output_wts) {
848
  const auto fname_wts = fname_out + ".wts";
849
- output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
850
  }
851
 
852
  // output to CSV file
853
  if (params.output_csv) {
854
  const auto fname_csv = fname_out + ".csv";
855
- output_csv(ctx, fname_csv.c_str());
856
  }
857
 
858
  // output to JSON file
859
  if (params.output_jsn) {
860
  const auto fname_jsn = fname_out + ".json";
861
- output_json(ctx, fname_jsn.c_str(), params);
862
  }
863
 
864
  // output to LRC file
865
  if (params.output_lrc) {
866
  const auto fname_lrc = fname_out + ".lrc";
867
- output_lrc(ctx, fname_lrc.c_str());
868
  }
869
  }
870
  }
 
210
  const std::vector<std::vector<float>> * pcmf32s;
211
  };
212
 
213
+ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
214
+ std::string speaker = "";
215
+ const int64_t n_samples = pcmf32s[0].size();
216
+
217
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
218
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
219
+
220
+ double energy0 = 0.0f;
221
+ double energy1 = 0.0f;
222
+
223
+ for (int64_t j = is0; j < is1; j++) {
224
+ energy0 += fabs(pcmf32s[0][j]);
225
+ energy1 += fabs(pcmf32s[1][j]);
226
+ }
227
+
228
+ if (energy0 > 1.1*energy1) {
229
+ speaker = "0";
230
+ } else if (energy1 > 1.1*energy0) {
231
+ speaker = "1";
232
+ } else {
233
+ speaker = "?";
234
+ }
235
+
236
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
237
+
238
+ if (!id_only) {
239
+ speaker.insert(0, "(speaker ");
240
+ speaker.append(")");
241
+ }
242
+
243
+ return speaker;
244
+ }
245
+
246
  void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
247
  const auto & params = *((whisper_print_user_data *) user_data)->params;
248
  const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
272
  }
273
 
274
  if (params.diarize && pcmf32s.size() == 2) {
275
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  }
277
 
278
  if (params.print_colors) {
 
306
  }
307
  }
308
 
309
+ bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
310
  std::ofstream fout(fname);
311
  if (!fout.is_open()) {
312
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
 
318
  const int n_segments = whisper_full_n_segments(ctx);
319
  for (int i = 0; i < n_segments; ++i) {
320
  const char * text = whisper_full_get_segment_text(ctx, i);
321
+ std::string speaker = "";
322
+
323
+ if (params.diarize && pcmf32s.size() == 2)
324
+ {
325
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
326
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
327
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
328
+ }
329
+
330
+ fout << speaker << text << "\n";
331
  }
332
 
333
  return true;
334
  }
335
 
336
+ bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
337
  std::ofstream fout(fname);
338
  if (!fout.is_open()) {
339
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
 
349
  const char * text = whisper_full_get_segment_text(ctx, i);
350
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
351
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
352
+ std::string speaker = "";
353
+
354
+ if (params.diarize && pcmf32s.size() == 2)
355
+ {
356
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
357
+ speaker.insert(0, "<v Speaker");
358
+ speaker.append(">");
359
+ }
360
 
361
  fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
362
+ fout << speaker << text << "\n\n";
363
  }
364
 
365
  return true;
366
  }
367
 
368
+ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
369
  std::ofstream fout(fname);
370
  if (!fout.is_open()) {
371
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
 
379
  const char * text = whisper_full_get_segment_text(ctx, i);
380
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
381
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
382
+ std::string speaker = "";
383
+
384
+ if (params.diarize && pcmf32s.size() == 2)
385
+ {
386
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
387
+ }
388
 
389
  fout << i + 1 + params.offset_n << "\n";
390
  fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
391
+ fout << speaker << text << "\n\n";
392
  }
393
 
394
  return true;
 
425
  return escaped;
426
  }
427
 
428
+ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
429
  std::ofstream fout(fname);
430
  if (!fout.is_open()) {
431
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
 
435
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
436
 
437
  const int n_segments = whisper_full_n_segments(ctx);
438
+ fout << "start,end,";
439
+ if (params.diarize && pcmf32s.size() == 2)
440
+ {
441
+ fout << "speaker,";
442
+ }
443
+ fout << "text\n";
444
+
445
  for (int i = 0; i < n_segments; ++i) {
446
  const char * text = whisper_full_get_segment_text(ctx, i);
447
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
 
449
  char * text_escaped = escape_double_quotes_and_backslashes(text);
450
 
451
  //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
452
+ fout << 10 * t0 << "," << 10 * t1 << ",";
453
+ if (params.diarize && pcmf32s.size() == 2)
454
+ {
455
+ fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
456
+ }
457
+ fout << "\"" << text_escaped << "\"\n";
458
  }
459
 
460
  return true;
461
  }
462
 
463
+ bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
464
  std::ofstream fout(fname);
465
  int indent = 0;
466
 
 
576
  value_i("from", t0 * 10, false);
577
  value_i("to", t1 * 10, true);
578
  end_obj(false);
579
+ value_s("text", text, !params.diarize);
580
+
581
+ if (params.diarize && pcmf32s.size() == 2) {
582
+ value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
583
+ }
584
  end_obj(i == (n_segments - 1));
585
  }
586
 
 
592
  // karaoke video generation
593
  // outputs a bash script that uses ffmpeg to generate a video with the subtitles
594
  // TODO: font parameter adjustments
595
+ bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
596
  std::ofstream fout(fname);
597
 
598
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
629
  fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
630
 
631
  bool is_first = true;
632
+ std::string speaker = "";
633
+
634
+ if (params.diarize && pcmf32s.size() == 2) {
635
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
636
+ }
637
 
638
  for (int j = 0; j < n; ++j) {
639
  const auto & token = tokens[j];
 
642
  continue;
643
  }
644
 
645
+ std::string txt_bg = "";
646
+ std::string txt_fg = ""; // highlight token
647
+ std::string txt_ul = ""; // underline
648
 
649
+ if (params.diarize && pcmf32s.size() == 2) {
650
+ txt_bg = speaker;
651
+ txt_fg = speaker;
652
+ txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
653
+ }
654
+
655
+ txt_bg.append("> ");
656
+ txt_fg.append("> ");
657
+ txt_ul.append("\\ \\ ");
658
 
659
  {
660
  for (int k = 0; k < n; ++k) {
 
717
  return true;
718
  }
719
 
720
+ bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
 
721
  std::ofstream fout(fname);
722
  if (!fout.is_open()) {
723
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
 
742
  char buf[16];
743
  snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
744
  std::string timestamp_lrc = std::string(buf);
745
+ std::string speaker = "";
746
+
747
+ if (params.diarize && pcmf32s.size() == 2)
748
+ {
749
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
750
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
751
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
752
+ }
753
 
754
+ fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
755
  }
756
 
757
  return true;
 
896
  // output to text file
897
  if (params.output_txt) {
898
  const auto fname_txt = fname_out + ".txt";
899
+ output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
900
  }
901
 
902
  // output to VTT file
903
  if (params.output_vtt) {
904
  const auto fname_vtt = fname_out + ".vtt";
905
+ output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
906
  }
907
 
908
  // output to SRT file
909
  if (params.output_srt) {
910
  const auto fname_srt = fname_out + ".srt";
911
+ output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
912
  }
913
 
914
  // output to WTS file
915
  if (params.output_wts) {
916
  const auto fname_wts = fname_out + ".wts";
917
+ output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
918
  }
919
 
920
  // output to CSV file
921
  if (params.output_csv) {
922
  const auto fname_csv = fname_out + ".csv";
923
+ output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
924
  }
925
 
926
  // output to JSON file
927
  if (params.output_jsn) {
928
  const auto fname_jsn = fname_out + ".json";
929
+ output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
930
  }
931
 
932
  // output to LRC file
933
  if (params.output_lrc) {
934
  const auto fname_lrc = fname_out + ".lrc";
935
+ output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
936
  }
937
  }
938
  }