ggerganov committed on
Commit
542cfc1
·
unverified ·
1 Parent(s): e079132

whisper : slightly faster Log Mel computation + n-1 FFT threads (#568)

Browse files
Files changed (1) hide show
  1. whisper.cpp +33 -19
whisper.cpp CHANGED
@@ -2306,10 +2306,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2306
  std::vector<float> fft_in(fft_size, 0.0);
2307
  std::vector<float> fft_out(2 * fft_size);
2308
  int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
2309
-
2310
  for (int i = ith; i < mel.n_len; i += n_threads) {
2311
  const int offset = i * fft_step;
2312
-
2313
  // apply Hanning window
2314
  for (int j = 0; j < fft_size; j++) {
2315
  if (offset + j < n_samples) {
@@ -2318,37 +2318,49 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2318
  fft_in[j] = 0.0;
2319
  }
2320
  }
2321
-
2322
  // FFT -> mag^2
2323
  fft(fft_in, fft_out);
2324
-
2325
  for (int j = 0; j < fft_size; j++) {
2326
  fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2327
  }
2328
  for (int j = 1; j < fft_size / 2; j++) {
2329
  fft_out[j] += fft_out[fft_size - j];
2330
  }
2331
-
2332
  if (speed_up) {
2333
  // scale down in the frequency domain results in a speed up in the time domain
2334
  for (int j = 0; j < n_fft; j++) {
2335
  fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
2336
  }
2337
  }
2338
-
2339
  // mel spectrogram
2340
  for (int j = 0; j < mel.n_mel; j++) {
2341
  double sum = 0.0;
2342
-
2343
- for (int k = 0; k < n_fft; k++) {
 
 
 
 
 
 
 
 
 
 
 
2344
  sum += fft_out[k] * filters.data[j * n_fft + k];
2345
  }
 
2346
  if (sum < 1e-10) {
2347
  sum = 1e-10;
2348
  }
2349
-
2350
  sum = log10(sum);
2351
-
2352
  mel.data[j * mel.n_len + i] = sum;
2353
  }
2354
  }
@@ -2383,17 +2395,19 @@ static bool log_mel_spectrogram(
2383
  //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2384
  //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2385
 
2386
- if (n_threads == 1) {
2387
- log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2388
- } else {
2389
- std::vector<std::thread> workers(n_threads);
2390
- for (int iw = 0; iw < n_threads; ++iw) {
2391
- workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
2392
- n_samples, fft_size, fft_step, n_threads,
2393
- std::cref(filters), speed_up, std::ref(mel));
2394
  }
2395
 
2396
- for (int iw = 0; iw < n_threads; ++iw) {
 
 
 
2397
  workers[iw].join();
2398
  }
2399
  }
 
2306
  std::vector<float> fft_in(fft_size, 0.0);
2307
  std::vector<float> fft_out(2 * fft_size);
2308
  int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
2309
+
2310
  for (int i = ith; i < mel.n_len; i += n_threads) {
2311
  const int offset = i * fft_step;
2312
+
2313
  // apply Hanning window
2314
  for (int j = 0; j < fft_size; j++) {
2315
  if (offset + j < n_samples) {
 
2318
  fft_in[j] = 0.0;
2319
  }
2320
  }
2321
+
2322
  // FFT -> mag^2
2323
  fft(fft_in, fft_out);
2324
+
2325
  for (int j = 0; j < fft_size; j++) {
2326
  fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2327
  }
2328
  for (int j = 1; j < fft_size / 2; j++) {
2329
  fft_out[j] += fft_out[fft_size - j];
2330
  }
2331
+
2332
  if (speed_up) {
2333
  // scale down in the frequency domain results in a speed up in the time domain
2334
  for (int j = 0; j < n_fft; j++) {
2335
  fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
2336
  }
2337
  }
2338
+
2339
  // mel spectrogram
2340
  for (int j = 0; j < mel.n_mel; j++) {
2341
  double sum = 0.0;
2342
+
2343
+ // unroll loop (suggested by GH user @lunixbochs)
2344
+ int k = 0;
2345
+ for (k = 0; k < n_fft - 3; k += 4) {
2346
+ sum +=
2347
+ fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
2348
+ fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
2349
+ fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
2350
+ fft_out[k + 3] * filters.data[j*n_fft + k + 3];
2351
+ }
2352
+
2353
+ // handle n_fft remainder
2354
+ for (; k < n_fft; k++) {
2355
  sum += fft_out[k] * filters.data[j * n_fft + k];
2356
  }
2357
+
2358
  if (sum < 1e-10) {
2359
  sum = 1e-10;
2360
  }
2361
+
2362
  sum = log10(sum);
2363
+
2364
  mel.data[j * mel.n_len + i] = sum;
2365
  }
2366
  }
 
2395
  //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2396
  //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2397
 
2398
+ {
2399
+ std::vector<std::thread> workers(n_threads - 1);
2400
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2401
+ workers[iw] = std::thread(
2402
+ log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
2403
+ n_samples, fft_size, fft_step, n_threads,
2404
+ std::cref(filters), speed_up, std::ref(mel));
 
2405
  }
2406
 
2407
+ // main thread
2408
+ log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2409
+
2410
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2411
  workers[iw].join();
2412
  }
2413
  }