ggerganov committed on
Commit d0e40a2 · unverified · 1 Parent(s): 80cee92

ref #4 : added transcription timestamps


Timestamps can be turned off with the "-nt" argument.
Performance has also improved.
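
As context for the change (not part of the commit itself), here is a minimal standalone sketch of how the printed timestamps come about. It mirrors the `to_timestamp()` helper and the `2*(tid - vocab.token_beg)` conversion added to `main.cpp` below; the sampled token id is a hypothetical example value.

```cpp
// Sketch only: each timestamp token past token_beg advances time by 2
// centiseconds (20 ms), and to_timestamp() formats centiseconds as mm:ss.mmm.
#include <cstdint>
#include <cstdio>
#include <string>

// same layout as the to_timestamp() helper in main.cpp: t is in centiseconds
std::string to_timestamp(int64_t t) {
    const int64_t sec = t/100;
    const int64_t min = sec/60;
    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) (sec - min*60), (int) (t - sec*100));
    return buf;
}

int main() {
    const int token_beg = 50363;           // token_beg from the diff (first timestamp token, 0.00 s)
    const int tid       = token_beg + 550; // hypothetical sampled timestamp token

    const int64_t t = 2*(tid - token_beg); // 550 steps * 20 ms = 1100 centiseconds = 11.0 s

    printf("[%s --> %s]\n", to_timestamp(0).c_str(), to_timestamp(t).c_str());
    // prints: [00:00.000 --> 00:11.000], matching the base.en example in the README
}
```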

Files changed (2)
  1. README.md +75 -16
  2. main.cpp +160 -40
README.md CHANGED
@@ -31,7 +31,7 @@ $ make base.en
31
 
32
  gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
33
  g++ -pthread -O3 -std=c++11 -c main.cpp
34
- g++ -o main ggml.o main.o
35
  ./main -h
36
 
37
  usage: ./main [options]
@@ -40,22 +40,17 @@ options:
40
  -h, --help show this help message and exit
41
  -s SEED, --seed SEED RNG seed (default: -1)
42
  -t N, --threads N number of threads to use during computation (default: 4)
43
- -T N, --tokens N maximum number of tokens to generate per iteration (default: 64)
44
  -v, --verbose verbose output
45
  --translate translate from source language to english
46
  -ps, --print_special print special tokens
 
47
  -l LANG, --language LANG spoken language (default: en)
48
  -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
49
  -f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav)
50
 
51
  bash ./download-ggml-model.sh base.en
52
  Downloading ggml model base.en ...
53
- models/ggml-base.en.bin 100%[=====================================>] 141.11M 7.84MB/s in 18s
54
- Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
55
- You can now use it like this:
56
-
57
- $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
58
-
59
 
60
  ===============================================
61
  Running base.en on all samples in ./samples ...
@@ -86,16 +81,17 @@ whisper_model_load: model size = 140.54 MB
86
  log_mel_spectrogram: n_sample = 176000, n_len = 1100
87
  log_mel_spectrogram: recording length: 11.000000 s
88
 
89
- main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe ...
90
 
91
- And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
92
 
93
- main: load time = 71.89 ms
94
- main: mel time = 36.95 ms
 
95
  main: sample time = 2.10 ms
96
- main: encode time = 700.94 ms / 116.82 ms per layer
97
- main: decode time = 86.14 ms
98
- main: total time = 898.72 ms
99
  ```
100
 
101
  The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -131,10 +127,73 @@ For example, you can use `ffmpeg` like this:
131
  ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
132
  ```
133
 
134
  ## Limitations
135
 
136
  - Very basic greedy sampling scheme - always pick the top token
137
- - No timestamps
138
  - Inference only
139
  - Runs on the CPU
140
  - Only mono-channel 16-bit WAV is supported
 
31
 
32
  gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
33
  g++ -pthread -O3 -std=c++11 -c main.cpp
34
+ g++ -pthread -o main ggml.o main.o
35
  ./main -h
36
 
37
  usage: ./main [options]
 
40
  -h, --help show this help message and exit
41
  -s SEED, --seed SEED RNG seed (default: -1)
42
  -t N, --threads N number of threads to use during computation (default: 4)
 
43
  -v, --verbose verbose output
44
  --translate translate from source language to english
45
  -ps, --print_special print special tokens
46
+ -nt, --no_timestamps do not print timestamps
47
  -l LANG, --language LANG spoken language (default: en)
48
  -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
49
  -f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav)
50
 
51
  bash ./download-ggml-model.sh base.en
52
  Downloading ggml model base.en ...
53
+ Model base.en already exists. Skipping download.
54
 
55
  ===============================================
56
  Running base.en on all samples in ./samples ...
 
81
  log_mel_spectrogram: n_sample = 176000, n_len = 1100
82
  log_mel_spectrogram: recording length: 11.000000 s
83
 
84
+ main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe, timestamps = 1 ...
85
 
86
+ [00:00.000 --> 00:11.000] And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
87
 
88
+
89
+ main: load time = 61.78 ms
90
+ main: mel time = 41.74 ms
91
  main: sample time = 2.10 ms
92
+ main: encode time = 718.60 ms / 119.77 ms per layer
93
+ main: decode time = 83.55 ms
94
+ main: total time = 908.15 ms
95
  ```
96
 
97
  The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
 
127
  ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
128
  ```
129
 
130
+ Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg) in less than a minute, using the `medium.en` model:
131
+
132
+ ```bash
133
+ $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
134
+ whisper_model_load: loading model from 'models/ggml-medium.en.bin'
135
+ whisper_model_load: n_vocab = 51864
136
+ whisper_model_load: n_audio_ctx = 1500
137
+ whisper_model_load: n_audio_state = 1024
138
+ whisper_model_load: n_audio_head = 16
139
+ whisper_model_load: n_audio_layer = 24
140
+ whisper_model_load: n_text_ctx = 448
141
+ whisper_model_load: n_text_state = 1024
142
+ whisper_model_load: n_text_head = 16
143
+ whisper_model_load: n_text_layer = 24
144
+ whisper_model_load: n_mels = 80
145
+ whisper_model_load: f16 = 1
146
+ whisper_model_load: type = 4
147
+ whisper_model_load: mem_required = 2786.00 MB
148
+ whisper_model_load: adding 1607 extra tokens
149
+ whisper_model_load: ggml ctx size = 1644.97 MB
150
+ whisper_model_load: memory size = 182.62 MB
151
+ whisper_model_load: model size = 1462.12 MB
152
+ log_mel_spectrogram: n_sample = 3179750, n_len = 19873
153
+ log_mel_spectrogram: recording length: 198.734375 s
154
+
155
+ main: processing 3179750 samples (198.7 sec), 8 threads, lang = english, task = transcribe, timestamps = 1 ...
156
+
157
+ [00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
158
+ [00:08.000 --> 00:17.000] At 9 o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
159
+ [00:17.000 --> 00:24.000] A short time later, debris was seen falling from the skies above Texas.
160
+ [00:24.000 --> 00:29.000] The Columbia's lost. There are no survivors.
161
+ [00:29.000 --> 00:32.000] On board was a crew of seven.
162
+ [00:32.000 --> 00:43.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool,
163
+ [00:43.000 --> 00:52.000] Dr. Kultner Aschavla, and Elon Ramon, a Colonel in the Israeli Air Force.
164
+ [00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
165
+ [00:58.000 --> 01:06.000] In an age when space flight has come to seem almost routine, it is easy to overlook the dangers of travel by rocket
166
+ [01:06.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
167
+ [01:12.000 --> 01:22.000] These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life.
168
+ [01:22.000 --> 01:30.000] Because of their courage, endearing, and idealism, we will miss them all the more.
169
+ [01:30.000 --> 01:40.000] All Americans today are thinking as well of the families of these men and women who have been given this sudden shock and grief.
170
+ [01:40.000 --> 01:45.000] You're not alone. Our entire nation agrees with you.
171
+ [01:45.000 --> 01:52.000] And those you love will always have the respect and gratitude of this country.
172
+ [01:52.000 --> 01:56.000] The cause in which they died will continue.
173
+ [01:56.000 --> 02:07.000] Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand.
174
+ [02:07.000 --> 02:11.000] Our journey into space will go on.
175
+ [02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
176
+ [02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
177
+ [02:22.000 --> 02:31.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens who created all these.
178
+ [02:31.000 --> 02:39.000] He who brings out the starry hosts one by one and calls them each by name."
179
+ [02:39.000 --> 02:46.000] Because of his great power and mighty strength, not one of them is missing.
180
+ [02:46.000 --> 02:55.000] The same creator who names the stars also knows the names of the seven souls we mourn today.
181
+ [02:55.000 --> 03:05.000] The crew of the shuttle Columbia did not return safely to Earth, yet we can pray that all are safely home.
182
+ [03:05.000 --> 03:14.000] May God bless the grieving families and may God continue to bless America.
183
+ [03:14.000 --> 03:24.000] [Music]
184
+
185
+
186
+ main: load time = 438.55 ms
187
+ main: mel time = 440.22 ms
188
+ main: sample time = 32.23 ms
189
+ main: encode time = 42329.63 ms / 1763.73 ms per layer
190
+ main: decode time = 15190.00 ms
191
+ main: total time = 58444.63 ms
192
+ ```
193
+
194
  ## Limitations
195
 
196
  - Very basic greedy sampling scheme - always pick the top token
 
197
  - Inference only
198
  - Runs on the CPU
199
  - Only mono-channel 16-bit WAV is supported
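
The limitations above note that only mono-channel 16-bit WAV input is supported, and the timing code in `main.cpp` works on a float sample buffer (`pcmf32`). As a point of reference, here is a minimal sketch, under the assumption that the raw 16-bit PCM has already been read from the WAV file, of the usual int16-to-float normalization; the helper name and buffers are hypothetical, not part of the commit.

```cpp
// Sketch only: scale mono 16-bit PCM samples into the [-1, 1) float range
// used by a buffer like pcmf32 in main.cpp.
#include <cstdint>
#include <vector>

std::vector<float> pcm16_to_f32(const std::vector<int16_t> & pcm16) {
    std::vector<float> pcmf32(pcm16.size());
    for (size_t i = 0; i < pcm16.size(); ++i) {
        pcmf32[i] = float(pcm16[i])/32768.0f; // int16 full scale -> unit range
    }
    return pcmf32;
}
```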
main.cpp CHANGED
@@ -206,6 +206,7 @@ struct whisper_vocab {
206
  id token_sot = 50257;
207
  id token_prev = 50360;
208
  id token_solm = 50361; // ??
 
209
  id token_beg = 50363;
210
 
211
  // available tasks
@@ -217,17 +218,20 @@ struct whisper_vocab {
217
  }
218
  };
219
 
220
  // command-line parameters
221
  struct whisper_params {
222
  int32_t seed = -1; // RNG seed, not used currently
223
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
224
 
225
- // sampling parameter - used for the greedy strategy
226
- int32_t max_tokens_per_iter = 64;
227
-
228
  bool verbose = false;
229
  bool translate = false;
230
  bool print_special_tokens = false;
 
231
 
232
  std::string language = "en";
233
  std::string model = "models/ggml-base.en.bin";
@@ -244,8 +248,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
244
  params.seed = std::stoi(argv[++i]);
245
  } else if (arg == "-t" || arg == "--threads") {
246
  params.n_threads = std::stoi(argv[++i]);
247
- } else if (arg == "-T" || arg == "--tokens") {
248
- params.max_tokens_per_iter = std::stoi(argv[++i]);
249
  } else if (arg == "-v" || arg == "--verbose") {
250
  params.verbose = true;
251
  } else if (arg == "--translate") {
@@ -259,6 +261,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
259
  }
260
  } else if (arg == "-ps" || arg == "--print_special") {
261
  params.print_special_tokens = true;
 
 
262
  } else if (arg == "-m" || arg == "--model") {
263
  params.model = argv[++i];
264
  } else if (arg == "-f" || arg == "--file") {
@@ -284,10 +288,10 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
284
  fprintf(stderr, " -h, --help show this help message and exit\n");
285
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
286
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
287
- fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
288
  fprintf(stderr, " -v, --verbose verbose output\n");
289
  fprintf(stderr, " --translate translate from source language to english\n");
290
  fprintf(stderr, " -ps, --print_special print special tokens\n");
 
291
  fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
292
  fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
293
  fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
@@ -591,6 +595,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
591
  vocab.token_sot++;
592
  vocab.token_prev++;
593
  vocab.token_solm++;
 
594
  vocab.token_beg++;
595
  }
596
 
@@ -605,6 +610,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
605
  word = "[_SOT_]";
606
  } else if (i == vocab.token_prev) {
607
  word = "[_PREV_]";
 
 
608
  } else if (i == vocab.token_beg) {
609
  word = "[_BEG_]";
610
  } else {
@@ -1842,15 +1849,13 @@ bool whisper_decode(
1842
  // TODO: temperature
1843
  whisper_vocab::id whisper_sample_best(
1844
  const whisper_vocab & vocab,
1845
- const float * probs,
1846
- double temp,
1847
- int offset = 0) {
1848
  int n_logits = vocab.id_to_token.size();
1849
 
1850
  std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1851
  probs_id.reserve(n_logits);
1852
 
1853
- for (int i = offset; i < n_logits; i++) {
1854
  probs_id.push_back(std::make_pair(probs[i], i));
1855
  }
1856
 
@@ -1872,13 +1877,49 @@ whisper_vocab::id whisper_sample_best(
1872
  //}
1873
 
1874
  int res = 0;
1875
- while (probs_id[res].second == vocab.token_solm && res < (int) probs_id.size() - 1) {
 
 
 
1876
  res++;
1877
  }
1878
 
1879
  return probs_id[res].second;
1880
  }
1881
 
1882
  // Cooley-Tukey FFT
1883
  // poor man's implementation - use something better
1884
  // input is real-valued
@@ -2032,6 +2073,20 @@ bool log_mel_spectrogram(
2032
  return true;
2033
  }
2034
 
2035
  int main(int argc, char ** argv) {
2036
  const int64_t t_main_start_us = ggml_time_us();
2037
 
@@ -2051,7 +2106,7 @@ int main(int argc, char ** argv) {
2051
 
2052
  int64_t t_load_us = 0;
2053
  int64_t t_mel_us = 0;
2054
- int64_t t_sample_us = 0;
2055
  int64_t t_encode_us = 0;
2056
  int64_t t_decode_us = 0;
2057
 
@@ -2128,10 +2183,12 @@ int main(int argc, char ** argv) {
2128
  printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
2129
  }
2130
  }
2131
- printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n",
2132
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
2133
  g_lang.at(params.language).second.c_str(),
2134
- params.translate ? "translate" : "transcribe");
 
 
2135
  }
2136
 
2137
  // the accumulated text context so far
@@ -2148,6 +2205,9 @@ int main(int argc, char ** argv) {
2148
  }
2149
  }
2150
 
 
 
 
2151
  // main loop
2152
  int seek = 0;
2153
  while (true) {
@@ -2165,7 +2225,7 @@ int main(int argc, char ** argv) {
2165
  return 1;
2166
  }
2167
 
2168
- t_encode_us = ggml_time_us() - t_start_us;
2169
  }
2170
 
2171
  std::vector<float> probs;
@@ -2192,11 +2252,16 @@ int main(int argc, char ** argv) {
2192
  int seek_delta = 100*CHUNK_SIZE;
2193
  whisper_vocab::id last_id = 0;
2194
 
 
2195
  //for (int i = 0; i < prompt.size(); i++) {
2196
  // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
2197
  //}
2198
 
2199
- printf("\n");
2200
  for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) {
2201
  // decode
2202
  if (prompt.size() > 0) {
@@ -2216,63 +2281,118 @@ int main(int argc, char ** argv) {
2216
  // very basic greedy sampling strategy:
2217
  //
2218
  // - always take the most probable token
2219
- // - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token
2220
- // and advance the sliding window by that amount
2221
- // - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too
2222
  //
2223
  // more sophisticated sampling strategies could be implemented here, but we keep it simple
2224
  // feel free to experiment!
2225
  //
2226
  {
2227
- // sample next token
2228
- const float temp = 1.0; // TODO
2229
-
2230
  const int n_vocab = model.hparams.n_vocab;
2231
 
2232
- whisper_vocab::id id = 0;
 
2233
 
2234
  {
2235
  const int64_t t_start_sample_us = ggml_time_us();
2236
 
2237
- id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), temp, i > params.max_tokens_per_iter ? vocab.token_beg : 0);
 
 
 
2238
 
2239
  t_sample_us += ggml_time_us() - t_start_sample_us;
2240
  }
2241
 
2242
- // end of text token
2243
- if (id == vocab.token_eot) {
2244
- break;
2245
- }
2246
-
2247
- // 2 consecutive time tokens
2248
- if (id > vocab.token_beg && last_id > vocab.token_beg) {
2249
  seek_delta = 2*(id - vocab.token_beg);
2250
- done = true;
2251
  }
2252
  last_id = id;
2253
 
2254
  // add it to the context
2255
  prompt.push_back(id);
2256
- prompt_past.push_back(id);
2257
- }
2258
 
2259
- // display text
2260
- for (auto id : prompt) {
2261
- if (params.print_special_tokens == false && id >= vocab.token_eot) {
2262
- continue;
2263
  }
2264
- printf("%s", vocab.id_to_token[id].c_str());
2265
  }
2266
- fflush(stdout);
2267
 
2268
  if (done) {
2269
  break;
2270
  }
2271
  }
2272
 
2273
  seek += seek_delta;
2274
  }
2275
 
2276
  // report timing
2277
  {
2278
  const int64_t t_main_end_us = ggml_time_us();
 
206
  id token_sot = 50257;
207
  id token_prev = 50360;
208
  id token_solm = 50361; // ??
209
+ id token_not = 50362; // no timestamps
210
  id token_beg = 50363;
211
 
212
  // available tasks
 
218
  }
219
  };
220
 
221
+ struct whisper_result {
222
+ whisper_vocab::id id;
223
+ int64_t t;
224
+ };
225
+
226
  // command-line parameters
227
  struct whisper_params {
228
  int32_t seed = -1; // RNG seed, not used currently
229
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
230
 
 
 
 
231
  bool verbose = false;
232
  bool translate = false;
233
  bool print_special_tokens = false;
234
+ bool no_timestamps = false;
235
 
236
  std::string language = "en";
237
  std::string model = "models/ggml-base.en.bin";
 
248
  params.seed = std::stoi(argv[++i]);
249
  } else if (arg == "-t" || arg == "--threads") {
250
  params.n_threads = std::stoi(argv[++i]);
 
 
251
  } else if (arg == "-v" || arg == "--verbose") {
252
  params.verbose = true;
253
  } else if (arg == "--translate") {
 
261
  }
262
  } else if (arg == "-ps" || arg == "--print_special") {
263
  params.print_special_tokens = true;
264
+ } else if (arg == "-nt" || arg == "--no_timestamps") {
265
+ params.no_timestamps = true;
266
  } else if (arg == "-m" || arg == "--model") {
267
  params.model = argv[++i];
268
  } else if (arg == "-f" || arg == "--file") {
 
288
  fprintf(stderr, " -h, --help show this help message and exit\n");
289
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
290
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
 
291
  fprintf(stderr, " -v, --verbose verbose output\n");
292
  fprintf(stderr, " --translate translate from source language to english\n");
293
  fprintf(stderr, " -ps, --print_special print special tokens\n");
294
+ fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
295
  fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
296
  fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
297
  fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
 
595
  vocab.token_sot++;
596
  vocab.token_prev++;
597
  vocab.token_solm++;
598
+ vocab.token_not++;
599
  vocab.token_beg++;
600
  }
601
 
 
610
  word = "[_SOT_]";
611
  } else if (i == vocab.token_prev) {
612
  word = "[_PREV_]";
613
+ } else if (i == vocab.token_not) {
614
+ word = "[_NOT_]";
615
  } else if (i == vocab.token_beg) {
616
  word = "[_BEG_]";
617
  } else {
 
1849
  // TODO: temperature
1850
  whisper_vocab::id whisper_sample_best(
1851
  const whisper_vocab & vocab,
1852
+ const float * probs) {
 
 
1853
  int n_logits = vocab.id_to_token.size();
1854
 
1855
  std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1856
  probs_id.reserve(n_logits);
1857
 
1858
+ for (int i = 0; i < n_logits; i++) {
1859
  probs_id.push_back(std::make_pair(probs[i], i));
1860
  }
1861
 
 
1877
  //}
1878
 
1879
  int res = 0;
1880
+ while ((probs_id[res].second == vocab.token_sot ||
1881
+ probs_id[res].second == vocab.token_solm ||
1882
+ probs_id[res].second == vocab.token_not) &&
1883
+ res < (int) probs_id.size() - 1) {
1884
  res++;
1885
  }
1886
 
1887
  return probs_id[res].second;
1888
  }
1889
 
1890
+ // samples only from the timestamps tokens
1891
+ whisper_vocab::id whisper_sample_timestamp(
1892
+ const whisper_vocab & vocab,
1893
+ const float * probs) {
1894
+ int n_logits = vocab.id_to_token.size();
1895
+
1896
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1897
+ probs_id.reserve(n_logits);
1898
+
1899
+ for (int i = vocab.token_beg + 1; i < n_logits; i++) {
1900
+ probs_id.push_back(std::make_pair(probs[i], i));
1901
+ }
1902
+
1903
+ const int top_k = 10;
1904
+
1905
+ // find the top K tokens
1906
+ std::partial_sort(
1907
+ probs_id.begin(),
1908
+ probs_id.begin() + top_k, probs_id.end(),
1909
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1910
+ return a.first > b.first;
1911
+ });
1912
+
1913
+ probs_id.resize(top_k);
1914
+
1915
+ //printf("\n");
1916
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
1917
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1918
+ //}
1919
+
1920
+ return probs_id[0].second;
1921
+ }
1922
+
1923
  // Cooley-Tukey FFT
1924
  // poor man's implementation - use something better
1925
  // input is real-valued
 
2073
  return true;
2074
  }
2075
 
2076
+ // 500 -> 00:05.000
2077
+ // 6000 -> 01:00.000
2078
+ std::string to_timestamp(int64_t t) {
2079
+ int64_t sec = t/100;
2080
+ int64_t msec = t - sec*100;
2081
+ int64_t min = sec/60;
2082
+ sec = sec - min*60;
2083
+
2084
+ char buf[32];
2085
+ snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
2086
+
2087
+ return std::string(buf);
2088
+ }
2089
+
2090
  int main(int argc, char ** argv) {
2091
  const int64_t t_main_start_us = ggml_time_us();
2092
 
 
2106
 
2107
  int64_t t_load_us = 0;
2108
  int64_t t_mel_us = 0;
2109
+ int64_t t_sample_us = 0;
2110
  int64_t t_encode_us = 0;
2111
  int64_t t_decode_us = 0;
2112
 
 
2183
  printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
2184
  }
2185
  }
2186
+ printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
2187
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
2188
  g_lang.at(params.language).second.c_str(),
2189
+ params.translate ? "translate" : "transcribe",
2190
+ params.no_timestamps ? 0 : 1);
2191
+ printf("\n");
2192
  }
2193
 
2194
  // the accumulated text context so far
 
2205
  }
2206
  }
2207
 
2208
+ // the generated text including timestamps
2209
+ std::vector<whisper_result> result_all;
2210
+
2211
  // main loop
2212
  int seek = 0;
2213
  while (true) {
 
2225
  return 1;
2226
  }
2227
 
2228
+ t_encode_us += ggml_time_us() - t_start_us;
2229
  }
2230
 
2231
  std::vector<float> probs;
 
2252
  int seek_delta = 100*CHUNK_SIZE;
2253
  whisper_vocab::id last_id = 0;
2254
 
2255
+ //printf("\n\n");
2256
  //for (int i = 0; i < prompt.size(); i++) {
2257
  // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
2258
  //}
2259
+ //printf("\n\n");
2260
+
2261
+ // the accumulated transcription in the current iteration
2262
+ int result_len = 0;
2263
+ std::vector<whisper_result> result_cur;
2264
 
 
2265
  for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) {
2266
  // decode
2267
  if (prompt.size() > 0) {
 
2281
  // very basic greedy sampling strategy:
2282
  //
2283
  // - always take the most probable token
 
 
 
2284
  //
2285
  // more sophisticated sampling strategies could be implemented here, but we keep it simple
2286
  // feel free to experiment!
2287
  //
2288
  {
 
 
 
2289
  const int n_vocab = model.hparams.n_vocab;
2290
 
2291
+ whisper_vocab::id id = 0;
2292
+ whisper_vocab::id tid = vocab.token_beg;
2293
 
2294
  {
2295
  const int64_t t_start_sample_us = ggml_time_us();
2296
 
2297
+ id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab));
2298
+ if (i > 0) {
2299
+ tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
2300
+ }
2301
 
2302
  t_sample_us += ggml_time_us() - t_start_sample_us;
2303
  }
2304
 
2305
+ // update sliding window
2306
+ if (id > vocab.token_beg) {
2307
  seek_delta = 2*(id - vocab.token_beg);
2308
+ result_len = i + 1;
2309
  }
2310
  last_id = id;
2311
 
2312
  // add it to the context
2313
  prompt.push_back(id);
2314
+ result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
 
2315
 
2316
+ // end of text token
2317
+ if (id == vocab.token_eot) {
2318
+ break;
 
2319
  }
 
2320
  }
 
2321
 
2322
  if (done) {
2323
  break;
2324
  }
2325
  }
2326
 
2327
+ result_cur.resize(result_len);
2328
+ result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
2329
+
2330
+ for (const auto & r : result_cur) {
2331
+ prompt_past.push_back(r.id);
2332
+ }
2333
+
2334
+ // print the text from this iteration
2335
+ if (result_cur.size() > 0) {
2336
+ auto t0 = result_cur.front().t;
2337
+
2338
+ std::string text = "";
2339
+ for (int i = 0; i < result_cur.size(); i++) {
2340
+ if (params.print_special_tokens == false && result_cur[i].id >= vocab.token_eot) {
2341
+ } else {
2342
+ text += vocab.id_to_token[result_cur[i].id];
2343
+ }
2344
+ if (result_cur[i].id > vocab.token_beg) {
2345
+ const auto t1 = result_cur[i].t;
2346
+ if (!text.empty()) {
2347
+ if (params.no_timestamps) {
2348
+ printf ("%s", text.c_str());
2349
+ fflush(stdout);
2350
+ } else {
2351
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
2352
+ }
2353
+ }
2354
+ text = "";
2355
+ while (result_cur[i].id > vocab.token_beg && i < result_cur.size()) {
2356
+ i++;
2357
+ }
2358
+ i--;
2359
+ t0 = result_cur[i].t;
2360
+ }
2361
+ }
2362
+
2363
+ if (!text.empty()) {
2364
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(seek + seek_delta).c_str(), text.c_str());
2365
+ }
2366
+ }
2367
+
2368
  seek += seek_delta;
2369
  }
2370
 
2371
+ // WIP: attempt for per-token timestamps
2372
+ //if (!params.no_timestamps && result_all.size() > 0) {
2373
+ // const int64_t dt = 500; // 5 second intervals
2374
+
2375
+ // int i0 = 0;
2376
+
2377
+ // int64_t t0 = result_all[0].t;
2378
+ // int64_t t1 = t0;
2379
+
2380
+ // printf("\n\n");
2381
+ // for (int i = 0; i < result_all.size(); ++i) {
2382
+ // printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t);
2383
+ // if (result_all[i].t - t0 > dt) {
2384
+ // t1 = result_all[i - 1].t;
2385
+ // printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
2386
+ // for (int j = i0; j < i; ++j) {
2387
+ // printf("%s", vocab.id_to_token.at(result_all[j].id).c_str());
2388
+ // }
2389
+ // printf("\n");
2390
+ // i0 = i;
2391
+ // t0 = result_all[i].t;
2392
+ // }
2393
+ // }
2394
+ //}
2395
+
2396
  // report timing
2397
  {
2398
  const int64_t t_main_end_us = ggml_time_us();