Spaces:
Sleeping
Sleeping
whisper : add CUDA-specific computation mel spectrograms (#2206)
Browse files* whisper : use polymorphic class to calculate mel spectrogram
* whisper : add cuda-specific mel spectrogram calculation
* whisper : conditionally compile cufftGetErrorString to avoid warnings
* build : add new files to makefile
* ruby : add new files to conf script
* build : fix typo in makefile
* whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
- CMakeLists.txt +7 -3
- Makefile +6 -3
- bindings/ruby/ext/extconf.rb +1 -0
- whisper-mel-cuda.cu +342 -0
- whisper-mel-cuda.hpp +3 -0
- whisper-mel.hpp +33 -0
- whisper.cpp +103 -93
- whisper.h +2 -0
CMakeLists.txt
CHANGED
|
@@ -364,12 +364,12 @@ if (WHISPER_CUDA)
|
|
| 364 |
if (WHISPER_STATIC)
|
| 365 |
if (WIN32)
|
| 366 |
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
|
| 367 |
-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
| 368 |
else ()
|
| 369 |
-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
| 370 |
endif()
|
| 371 |
else()
|
| 372 |
-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
| 373 |
endif()
|
| 374 |
|
| 375 |
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
|
|
@@ -679,6 +679,10 @@ add_library(${TARGET}
|
|
| 679 |
whisper.cpp
|
| 680 |
)
|
| 681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
include_directories (
|
| 683 |
.
|
| 684 |
)
|
|
|
|
| 364 |
if (WHISPER_STATIC)
|
| 365 |
if (WIN32)
|
| 366 |
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
|
| 367 |
+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
|
| 368 |
else ()
|
| 369 |
+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
|
| 370 |
endif()
|
| 371 |
else()
|
| 372 |
+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
|
| 373 |
endif()
|
| 374 |
|
| 375 |
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
|
|
|
|
| 679 |
whisper.cpp
|
| 680 |
)
|
| 681 |
|
| 682 |
+
if (WHISPER_CUDA)
|
| 683 |
+
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
|
| 684 |
+
endif()
|
| 685 |
+
|
| 686 |
include_directories (
|
| 687 |
.
|
| 688 |
)
|
Makefile
CHANGED
|
@@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
|
|
| 286 |
|
| 287 |
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
| 288 |
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
| 289 |
-
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
|
| 290 |
-
WHISPER_OBJ += ggml-cuda.o
|
| 291 |
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
|
| 292 |
NVCC = nvcc
|
| 293 |
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
|
@@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
|
|
| 299 |
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
| 300 |
endif
|
| 301 |
|
|
|
|
|
|
|
|
|
|
| 302 |
ifdef WHISPER_HIPBLAS
|
| 303 |
ROCM_PATH ?= /opt/rocm
|
| 304 |
HIPCC ?= $(ROCM_PATH)/bin/hipcc
|
|
@@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
|
|
| 404 |
|
| 405 |
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
|
| 406 |
|
| 407 |
-
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
| 408 |
$(CXX) $(CXXFLAGS) -c $< -o $@
|
| 409 |
|
| 410 |
ifndef WHISPER_COREML
|
|
|
|
| 286 |
|
| 287 |
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
| 288 |
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
| 289 |
+
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
|
| 290 |
+
WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
|
| 291 |
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
|
| 292 |
NVCC = nvcc
|
| 293 |
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
|
|
|
| 299 |
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
| 300 |
endif
|
| 301 |
|
| 302 |
+
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
|
| 303 |
+
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
| 304 |
+
|
| 305 |
ifdef WHISPER_HIPBLAS
|
| 306 |
ROCM_PATH ?= /opt/rocm
|
| 307 |
HIPCC ?= $(ROCM_PATH)/bin/hipcc
|
|
|
|
| 407 |
|
| 408 |
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
|
| 409 |
|
| 410 |
+
whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
|
| 411 |
$(CXX) $(CXXFLAGS) -c $< -o $@
|
| 412 |
|
| 413 |
ifndef WHISPER_COREML
|
bindings/ruby/ext/extconf.rb
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
require 'mkmf'
|
| 2 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
| 3 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
|
|
|
| 4 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
| 5 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
| 6 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")
|
|
|
|
| 1 |
require 'mkmf'
|
| 2 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
| 3 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
| 4 |
+
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
|
| 5 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
| 6 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
| 7 |
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")
|
whisper-mel-cuda.cu
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
|
| 2 |
+
#include "whisper-mel-cuda.hpp"
|
| 3 |
+
#include "whisper.h"
|
| 4 |
+
|
| 5 |
+
#include <cuda.h>
|
| 6 |
+
#include <cuda_runtime.h>
|
| 7 |
+
#include <cufft.h>
|
| 8 |
+
#include <cublas_v2.h>
|
| 9 |
+
#include <cuComplex.h>
|
| 10 |
+
#include <cub/device/device_reduce.cuh>
|
| 11 |
+
|
| 12 |
+
#include <algorithm>
|
| 13 |
+
|
| 14 |
+
#if defined(_MSC_VER)
|
| 15 |
+
#pragma warning(disable: 4324) // added padding
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#ifndef NDEBUG
|
| 19 |
+
# define DO_CHECKS 1
|
| 20 |
+
#else
|
| 21 |
+
# define DO_CHECKS 0
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
namespace {
|
| 25 |
+
|
| 26 |
+
#if DO_CHECKS
|
| 27 |
+
const char* cufftGetErrorString(cufftResult_t res) {
|
| 28 |
+
switch (res) {
|
| 29 |
+
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
|
| 30 |
+
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
|
| 31 |
+
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
|
| 32 |
+
case CUFFT_INVALID_TYPE: return "No longer used";
|
| 33 |
+
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
|
| 34 |
+
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
|
| 35 |
+
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
|
| 36 |
+
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
|
| 37 |
+
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
|
| 38 |
+
case CUFFT_UNALIGNED_DATA: return "No longer used";
|
| 39 |
+
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
|
| 40 |
+
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
|
| 41 |
+
case CUFFT_PARSE_ERROR: return "Internal plan database error";
|
| 42 |
+
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
|
| 43 |
+
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
|
| 44 |
+
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
|
| 45 |
+
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
|
| 46 |
+
default: return "Unknown error";
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
# define CUDA_CHECK_GEN(err, success, error_fn) \
|
| 51 |
+
do { \
|
| 52 |
+
auto err_ = (err); \
|
| 53 |
+
if (err_ != (success)) { \
|
| 54 |
+
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
|
| 55 |
+
} \
|
| 56 |
+
} while (0)
|
| 57 |
+
#else
|
| 58 |
+
# define CUDA_CHECK_GEN(err, success, error_fn) err
|
| 59 |
+
#endif
|
| 60 |
+
|
| 61 |
+
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
| 62 |
+
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
|
| 63 |
+
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
|
| 64 |
+
|
| 65 |
+
__global__ void k_fill_stft_input(
|
| 66 |
+
const float * padded_samples,
|
| 67 |
+
const int n_frames,
|
| 68 |
+
const float * hann_window,
|
| 69 |
+
float * stft_in
|
| 70 |
+
) {
|
| 71 |
+
auto y = blockIdx.y * blockDim.y + threadIdx.y;
|
| 72 |
+
// if (y >= n_frames) return;
|
| 73 |
+
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
| 74 |
+
// if (x >= WHISPER_N_FFT) return;
|
| 75 |
+
|
| 76 |
+
auto line = padded_samples + y * WHISPER_HOP_LENGTH;
|
| 77 |
+
auto outLine = stft_in + y * WHISPER_N_FFT;
|
| 78 |
+
|
| 79 |
+
outLine[x] = line[x] * hann_window[x];
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
__global__ void k_calc_magnitudes(
|
| 83 |
+
const cuComplex* stft_out,
|
| 84 |
+
const int n_frames,
|
| 85 |
+
float * magnitudes
|
| 86 |
+
) {
|
| 87 |
+
auto y = blockIdx.y * blockDim.y + threadIdx.y;
|
| 88 |
+
// if (y >= n_frames) return;
|
| 89 |
+
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
| 90 |
+
// if (x >= WHISPER_N_FFT_HALF) return;
|
| 91 |
+
|
| 92 |
+
auto idx = y * WHISPER_N_FFT_HALF + x;
|
| 93 |
+
|
| 94 |
+
auto r = stft_out[idx].x;
|
| 95 |
+
auto i = stft_out[idx].y;
|
| 96 |
+
magnitudes[idx] = r * r + i * i;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
__global__ void k_calc_log_mel(
|
| 100 |
+
const float * mel_data,
|
| 101 |
+
const int n_mel,
|
| 102 |
+
const float * max_val,
|
| 103 |
+
float * log_mel
|
| 104 |
+
) {
|
| 105 |
+
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
| 106 |
+
if (x >= n_mel) return;
|
| 107 |
+
|
| 108 |
+
float val = mel_data[x];
|
| 109 |
+
|
| 110 |
+
constexpr float e = 1e-10f;
|
| 111 |
+
if (val < e) val = e;
|
| 112 |
+
|
| 113 |
+
val = log10(val);
|
| 114 |
+
|
| 115 |
+
const float max = log10(*max_val) - 8.f;
|
| 116 |
+
if (val < max) val = max;
|
| 117 |
+
|
| 118 |
+
log_mel[x] = (val + 4) / 4;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
void fill_stft_input(
|
| 122 |
+
const float * padded_samples,
|
| 123 |
+
int n_frames,
|
| 124 |
+
const float * hann_window,
|
| 125 |
+
float * stft_in,
|
| 126 |
+
cudaStream_t stream
|
| 127 |
+
) {
|
| 128 |
+
dim3 block(WHISPER_N_FFT, 1);
|
| 129 |
+
dim3 grid(1, n_frames);
|
| 130 |
+
|
| 131 |
+
k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
void calc_magnitudes(
|
| 135 |
+
const cuComplex* stft_out,
|
| 136 |
+
int n_frames,
|
| 137 |
+
float * magnitudes,
|
| 138 |
+
cudaStream_t stream
|
| 139 |
+
) {
|
| 140 |
+
dim3 block(WHISPER_N_FFT_HALF, 1);
|
| 141 |
+
dim3 grid(1, n_frames);
|
| 142 |
+
k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
|
| 146 |
+
|
| 147 |
+
size_t get_log_mel_temp_storage_size() {
|
| 148 |
+
constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
|
| 149 |
+
constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 150 |
+
constexpr auto maxMels = 160;
|
| 151 |
+
|
| 152 |
+
size_t nbytes = 0;
|
| 153 |
+
float * temp = nullptr;
|
| 154 |
+
cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, maxFrames * maxMels);
|
| 155 |
+
return nbytes + LOG_MEL_PREFIX_SIZE;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
void calc_log_mel(
|
| 159 |
+
const float * mel_data,
|
| 160 |
+
int n_mel,
|
| 161 |
+
void * tempStorage,
|
| 162 |
+
int tempStorageSize,
|
| 163 |
+
float * log_mel,
|
| 164 |
+
cudaStream_t stream
|
| 165 |
+
) {
|
| 166 |
+
float * max_val = reinterpret_cast<float *>(tempStorage);
|
| 167 |
+
void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
|
| 168 |
+
|
| 169 |
+
size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
|
| 170 |
+
cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
|
| 171 |
+
|
| 172 |
+
int block = 256;
|
| 173 |
+
int grid = (n_mel + block - 1) / block;
|
| 174 |
+
|
| 175 |
+
k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
class mel_calc_cuda : public whisper_mel_calc {
|
| 179 |
+
const int m_n_mel;
|
| 180 |
+
|
| 181 |
+
ggml_backend_t m_backend = nullptr;
|
| 182 |
+
|
| 183 |
+
cudaStream_t m_stream = nullptr;
|
| 184 |
+
cublasHandle_t m_cublas_handle = nullptr;
|
| 185 |
+
|
| 186 |
+
float * m_hann_window = nullptr;
|
| 187 |
+
|
| 188 |
+
size_t m_cufft_workspace_size = 0;
|
| 189 |
+
void * m_cufft_workspace = nullptr;
|
| 190 |
+
|
| 191 |
+
float * m_filters = nullptr;
|
| 192 |
+
|
| 193 |
+
size_t m_log_mel_temp_storage_size = 0;
|
| 194 |
+
void * m_log_mel_temp_storage = nullptr;
|
| 195 |
+
public:
|
| 196 |
+
mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters)
|
| 197 |
+
: m_n_mel(filters.n_mel)
|
| 198 |
+
, m_backend(backend)
|
| 199 |
+
{
|
| 200 |
+
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
| 201 |
+
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
|
| 202 |
+
}
|
| 203 |
+
assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
|
| 204 |
+
|
| 205 |
+
CUDA_CHECK(cudaStreamCreate(&m_stream));
|
| 206 |
+
CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
|
| 207 |
+
CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
|
| 208 |
+
CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
|
| 209 |
+
|
| 210 |
+
// create Hann window
|
| 211 |
+
{
|
| 212 |
+
auto hw = whisper_mel_calc::hann_window();
|
| 213 |
+
CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
|
| 214 |
+
CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
// create working area
|
| 218 |
+
{
|
| 219 |
+
constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
|
| 220 |
+
constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 221 |
+
CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, maxFrames, &m_cufft_workspace_size));
|
| 222 |
+
CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
// fill filters
|
| 226 |
+
{
|
| 227 |
+
auto& f = filters.data;
|
| 228 |
+
CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
|
| 229 |
+
CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
{
|
| 233 |
+
m_log_mel_temp_storage_size = get_log_mel_temp_storage_size();
|
| 234 |
+
CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
~mel_calc_cuda() {
|
| 239 |
+
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
| 240 |
+
CUDA_CHECK(cudaStreamDestroy(m_stream));
|
| 241 |
+
CUDA_CHECK(cudaFree(m_hann_window));
|
| 242 |
+
CUDA_CHECK(cudaFree(m_cufft_workspace));
|
| 243 |
+
CUDA_CHECK(cudaFree(m_filters));
|
| 244 |
+
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) const override {
|
| 248 |
+
const size_t mirror_pad = WHISPER_N_FFT / 2;
|
| 249 |
+
const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;
|
| 250 |
+
|
| 251 |
+
// pad
|
| 252 |
+
std::vector<float> padded_samples(padded_size);
|
| 253 |
+
std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
|
| 254 |
+
std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
|
| 255 |
+
|
| 256 |
+
// fill the rest of the data
|
| 257 |
+
// it should canonically be mirrored at the end as well,
|
| 258 |
+
// but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
|
| 259 |
+
std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);
|
| 260 |
+
|
| 261 |
+
const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 262 |
+
|
| 263 |
+
float * cu_padded_samples = nullptr;
|
| 264 |
+
CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
|
| 265 |
+
CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
| 266 |
+
|
| 267 |
+
float * stft_in = nullptr; // contiguous buffer for stft input
|
| 268 |
+
CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
|
| 269 |
+
|
| 270 |
+
fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);
|
| 271 |
+
|
| 272 |
+
cufftComplex* stft_out;
|
| 273 |
+
CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));
|
| 274 |
+
|
| 275 |
+
cufftHandle plan;
|
| 276 |
+
CUFFT_CHECK(cufftCreate(&plan));
|
| 277 |
+
CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
|
| 278 |
+
{
|
| 279 |
+
size_t waSize;
|
| 280 |
+
CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
|
| 281 |
+
assert(waSize <= m_cufft_workspace_size);
|
| 282 |
+
CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
|
| 283 |
+
CUFFT_CHECK(cufftSetStream(plan, m_stream));
|
| 284 |
+
}
|
| 285 |
+
CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));
|
| 286 |
+
|
| 287 |
+
const auto n_mag_frames = n_frames - 1; // drop last frame
|
| 288 |
+
float * magnitudes;
|
| 289 |
+
CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
|
| 290 |
+
calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);
|
| 291 |
+
|
| 292 |
+
float * mel_data = nullptr;
|
| 293 |
+
CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
|
| 294 |
+
|
| 295 |
+
const float fone = 1.0f, fzero = 0.0f;
|
| 296 |
+
CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
|
| 297 |
+
int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
|
| 298 |
+
&fone,
|
| 299 |
+
magnitudes, WHISPER_N_FFT_HALF,
|
| 300 |
+
m_filters, WHISPER_N_FFT_HALF,
|
| 301 |
+
&fzero,
|
| 302 |
+
mel_data, int(n_mag_frames)));
|
| 303 |
+
|
| 304 |
+
float * log_mels = nullptr;
|
| 305 |
+
CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
|
| 306 |
+
|
| 307 |
+
calc_log_mel(
|
| 308 |
+
mel_data, int(m_n_mel * n_mag_frames),
|
| 309 |
+
m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
|
| 310 |
+
log_mels, m_stream);
|
| 311 |
+
|
| 312 |
+
whisper_mel ret;
|
| 313 |
+
ret.n_mel = m_n_mel;
|
| 314 |
+
ret.n_len = int(n_mag_frames);
|
| 315 |
+
// Calculate semi-padded sample length to ensure compatibility
|
| 316 |
+
ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 317 |
+
ret.data.resize(m_n_mel * n_mag_frames);
|
| 318 |
+
CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));
|
| 319 |
+
|
| 320 |
+
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
| 321 |
+
|
| 322 |
+
// cleanup
|
| 323 |
+
CUFFT_CHECK(cufftDestroy(plan));
|
| 324 |
+
CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
|
| 325 |
+
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
|
| 326 |
+
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
|
| 327 |
+
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
|
| 328 |
+
CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
|
| 329 |
+
CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
|
| 330 |
+
|
| 331 |
+
return ret;
|
| 332 |
+
}
|
| 333 |
+
};
|
| 334 |
+
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
|
| 338 |
+
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
| 339 |
+
return nullptr;
|
| 340 |
+
}
|
| 341 |
+
return new mel_calc_cuda(backend, filters);
|
| 342 |
+
}
|
whisper-mel-cuda.hpp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "whisper-mel.hpp"
|
| 2 |
+
|
| 3 |
+
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);
|
whisper-mel.hpp
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include "ggml-backend.h"
|
| 3 |
+
#include <vector>
|
| 4 |
+
|
| 5 |
+
struct whisper_mel {
|
| 6 |
+
int n_len;
|
| 7 |
+
int n_len_org;
|
| 8 |
+
int n_mel;
|
| 9 |
+
|
| 10 |
+
std::vector<float> data;
|
| 11 |
+
};
|
| 12 |
+
|
| 13 |
+
struct whisper_filters {
|
| 14 |
+
int32_t n_mel;
|
| 15 |
+
int32_t n_fft;
|
| 16 |
+
|
| 17 |
+
std::vector<float> data;
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
template <typename T>
|
| 21 |
+
struct whisper_span {
|
| 22 |
+
T * data;
|
| 23 |
+
int len;
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
struct whisper_mel_calc {
|
| 27 |
+
virtual ~whisper_mel_calc();
|
| 28 |
+
virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) const = 0;
|
| 29 |
+
static whisper_span<const float> hann_window();
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
// returns a new pointer which needs to be freed with delete
|
| 33 |
+
whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters);
|
whisper.cpp
CHANGED
|
@@ -10,6 +10,7 @@
|
|
| 10 |
|
| 11 |
#ifdef GGML_USE_CUDA
|
| 12 |
#include "ggml-cuda.h"
|
|
|
|
| 13 |
#endif
|
| 14 |
|
| 15 |
#ifdef GGML_USE_SYCL
|
|
@@ -24,6 +25,8 @@
|
|
| 24 |
#include "ggml-alloc.h"
|
| 25 |
#include "ggml-backend.h"
|
| 26 |
|
|
|
|
|
|
|
| 27 |
#include <atomic>
|
| 28 |
#include <algorithm>
|
| 29 |
#include <cassert>
|
|
@@ -380,21 +383,6 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
|
| 380 |
|
| 381 |
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
| 382 |
|
| 383 |
-
struct whisper_mel {
|
| 384 |
-
int n_len;
|
| 385 |
-
int n_len_org;
|
| 386 |
-
int n_mel;
|
| 387 |
-
|
| 388 |
-
std::vector<float> data;
|
| 389 |
-
};
|
| 390 |
-
|
| 391 |
-
struct whisper_filters {
|
| 392 |
-
int32_t n_mel;
|
| 393 |
-
int32_t n_fft;
|
| 394 |
-
|
| 395 |
-
std::vector<float> data;
|
| 396 |
-
};
|
| 397 |
-
|
| 398 |
struct whisper_vocab {
|
| 399 |
using id = int32_t;
|
| 400 |
using token = std::string;
|
|
@@ -883,6 +871,8 @@ struct whisper_context {
|
|
| 883 |
whisper_model model;
|
| 884 |
whisper_vocab vocab;
|
| 885 |
|
|
|
|
|
|
|
| 886 |
whisper_state * state = nullptr;
|
| 887 |
|
| 888 |
ggml_backend_t backend = nullptr;
|
|
@@ -2894,6 +2884,14 @@ struct whisper_global_cache {
|
|
| 2894 |
} global_cache;
|
| 2895 |
}
|
| 2896 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2897 |
// naive Discrete Fourier Transform
|
| 2898 |
// input is real-valued
|
| 2899 |
// output is complex-valued
|
|
@@ -2976,8 +2974,10 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2976 |
}
|
| 2977 |
|
| 2978 |
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
| 2979 |
-
int n_samples, int
|
| 2980 |
const whisper_filters & filters, whisper_mel & mel) {
|
|
|
|
|
|
|
| 2981 |
std::vector<float> fft_in(frame_size, 0.0);
|
| 2982 |
std::vector<float> fft_out(2 * frame_size);
|
| 2983 |
int n_fft = filters.n_fft;
|
|
@@ -3041,99 +3041,95 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
|
|
| 3041 |
}
|
| 3042 |
}
|
| 3043 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3044 |
|
| 3045 |
-
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
|
| 3046 |
-
|
| 3047 |
-
|
| 3048 |
-
|
| 3049 |
-
const int n_samples,
|
| 3050 |
-
const int /*sample_rate*/,
|
| 3051 |
-
const int frame_size,
|
| 3052 |
-
const int frame_step,
|
| 3053 |
-
const int n_mel,
|
| 3054 |
-
const int n_threads,
|
| 3055 |
-
const whisper_filters & filters,
|
| 3056 |
-
const bool debug,
|
| 3057 |
-
whisper_mel & mel) {
|
| 3058 |
-
const int64_t t_start_us = ggml_time_us();
|
| 3059 |
|
| 3060 |
-
|
| 3061 |
-
|
| 3062 |
-
|
| 3063 |
|
| 3064 |
-
|
| 3065 |
-
|
| 3066 |
-
int64_t stage_2_pad = frame_size / 2;
|
| 3067 |
|
| 3068 |
-
|
| 3069 |
-
|
| 3070 |
-
|
| 3071 |
-
|
| 3072 |
|
| 3073 |
-
|
| 3074 |
-
|
| 3075 |
|
| 3076 |
-
|
| 3077 |
-
|
| 3078 |
|
| 3079 |
-
|
| 3080 |
-
|
| 3081 |
-
|
| 3082 |
-
|
| 3083 |
-
|
| 3084 |
-
|
| 3085 |
-
|
|
|
|
| 3086 |
|
| 3087 |
|
| 3088 |
-
|
| 3089 |
-
|
| 3090 |
-
|
| 3091 |
-
|
| 3092 |
-
|
| 3093 |
-
|
| 3094 |
-
|
| 3095 |
-
|
| 3096 |
-
|
| 3097 |
-
// main thread
|
| 3098 |
-
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
|
| 3099 |
|
| 3100 |
-
|
| 3101 |
-
|
| 3102 |
-
}
|
| 3103 |
-
}
|
| 3104 |
|
| 3105 |
-
|
| 3106 |
-
|
| 3107 |
-
|
| 3108 |
-
if (mel.data[i] > mmax) {
|
| 3109 |
-
mmax = mel.data[i];
|
| 3110 |
}
|
| 3111 |
-
}
|
| 3112 |
-
|
| 3113 |
-
mmax -= 8.0;
|
| 3114 |
|
| 3115 |
-
|
| 3116 |
-
|
| 3117 |
-
|
|
|
|
|
|
|
|
|
|
| 3118 |
}
|
| 3119 |
|
| 3120 |
-
|
| 3121 |
-
}
|
| 3122 |
|
| 3123 |
-
|
|
|
|
|
|
|
|
|
|
| 3124 |
|
| 3125 |
-
|
| 3126 |
-
if (debug) {
|
| 3127 |
-
std::ofstream outFile("log_mel_spectrogram.json");
|
| 3128 |
-
outFile << "[";
|
| 3129 |
-
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
|
| 3130 |
-
outFile << mel.data[i] << ", ";
|
| 3131 |
}
|
| 3132 |
-
|
| 3133 |
-
|
| 3134 |
}
|
|
|
|
|
|
|
| 3135 |
|
| 3136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3137 |
}
|
| 3138 |
|
| 3139 |
// split text into tokens
|
|
@@ -3593,6 +3589,8 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
|
|
| 3593 |
return nullptr;
|
| 3594 |
}
|
| 3595 |
|
|
|
|
|
|
|
| 3596 |
loader->close(loader->context);
|
| 3597 |
|
| 3598 |
return ctx;
|
|
@@ -3713,6 +3711,8 @@ void whisper_free(struct whisper_context * ctx) {
|
|
| 3713 |
|
| 3714 |
ggml_backend_free(ctx->backend);
|
| 3715 |
|
|
|
|
|
|
|
| 3716 |
delete ctx;
|
| 3717 |
}
|
| 3718 |
}
|
|
@@ -3730,11 +3730,21 @@ void whisper_free_params(struct whisper_full_params * params) {
|
|
| 3730 |
}
|
| 3731 |
|
| 3732 |
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 3733 |
-
|
| 3734 |
-
|
| 3735 |
-
|
| 3736 |
-
}
|
| 3737 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3738 |
return 0;
|
| 3739 |
}
|
| 3740 |
|
|
|
|
| 10 |
|
| 11 |
#ifdef GGML_USE_CUDA
|
| 12 |
#include "ggml-cuda.h"
|
| 13 |
+
#include "whisper-mel-cuda.hpp"
|
| 14 |
#endif
|
| 15 |
|
| 16 |
#ifdef GGML_USE_SYCL
|
|
|
|
| 25 |
#include "ggml-alloc.h"
|
| 26 |
#include "ggml-backend.h"
|
| 27 |
|
| 28 |
+
#include "whisper-mel.hpp"
|
| 29 |
+
|
| 30 |
#include <atomic>
|
| 31 |
#include <algorithm>
|
| 32 |
#include <cassert>
|
|
|
|
| 383 |
|
| 384 |
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
| 385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
struct whisper_vocab {
|
| 387 |
using id = int32_t;
|
| 388 |
using token = std::string;
|
|
|
|
| 871 |
whisper_model model;
|
| 872 |
whisper_vocab vocab;
|
| 873 |
|
| 874 |
+
whisper_mel_calc * mel_calc = nullptr;
|
| 875 |
+
|
| 876 |
whisper_state * state = nullptr;
|
| 877 |
|
| 878 |
ggml_backend_t backend = nullptr;
|
|
|
|
| 2884 |
} global_cache;
|
| 2885 |
}
|
| 2886 |
|
| 2887 |
+
// Mel spectrogram
|
| 2888 |
+
|
| 2889 |
+
whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
|
| 2890 |
+
|
| 2891 |
+
whisper_span<const float> whisper_mel_calc::hann_window() {
|
| 2892 |
+
return {global_cache.hann_window, WHISPER_N_FFT};
|
| 2893 |
+
}
|
| 2894 |
+
|
| 2895 |
// naive Discrete Fourier Transform
|
| 2896 |
// input is real-valued
|
| 2897 |
// output is complex-valued
|
|
|
|
| 2974 |
}
|
| 2975 |
|
| 2976 |
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
| 2977 |
+
int n_samples, int n_threads,
|
| 2978 |
const whisper_filters & filters, whisper_mel & mel) {
|
| 2979 |
+
const auto frame_size = WHISPER_N_FFT;
|
| 2980 |
+
const auto frame_step = WHISPER_HOP_LENGTH;
|
| 2981 |
std::vector<float> fft_in(frame_size, 0.0);
|
| 2982 |
std::vector<float> fft_out(2 * frame_size);
|
| 2983 |
int n_fft = filters.n_fft;
|
|
|
|
| 3041 |
}
|
| 3042 |
}
|
| 3043 |
}
|
| 3044 |
+
namespace {
|
| 3045 |
+
struct mel_calc_cpu : public whisper_mel_calc {
|
| 3046 |
+
const whisper_filters& m_filters;
|
| 3047 |
+
mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
|
| 3048 |
|
| 3049 |
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
|
| 3050 |
+
whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
|
| 3051 |
+
// Hann window
|
| 3052 |
+
const float * hann = global_cache.hann_window;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3053 |
|
| 3054 |
+
// Calculate the length of padding
|
| 3055 |
+
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
| 3056 |
+
int64_t stage_2_pad = WHISPER_N_FFT / 2;
|
| 3057 |
|
| 3058 |
+
const int n_samples = int(ssamples.len);
|
| 3059 |
+
const float * samples = ssamples.data;
|
|
|
|
| 3060 |
|
| 3061 |
+
// Initialize a vector and copy data from C array to it.
|
| 3062 |
+
std::vector<float> samples_padded;
|
| 3063 |
+
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
|
| 3064 |
+
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
|
| 3065 |
|
| 3066 |
+
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
|
| 3067 |
+
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
|
| 3068 |
|
| 3069 |
+
// reflective pad 200 samples at the beginning of audio
|
| 3070 |
+
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
|
| 3071 |
|
| 3072 |
+
whisper_mel mel;
|
| 3073 |
+
mel.n_mel = m_filters.n_mel;
|
| 3074 |
+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
|
| 3075 |
+
// Calculate number of frames + remove the last frame
|
| 3076 |
+
mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 3077 |
+
// Calculate semi-padded sample length to ensure compatibility
|
| 3078 |
+
mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
| 3079 |
+
mel.data.resize(mel.n_mel * mel.n_len);
|
| 3080 |
|
| 3081 |
|
| 3082 |
+
{
|
| 3083 |
+
std::vector<std::thread> workers(n_threads - 1);
|
| 3084 |
+
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
| 3085 |
+
workers[iw] = std::thread(
|
| 3086 |
+
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
| 3087 |
+
n_samples + stage_2_pad, n_threads,
|
| 3088 |
+
std::cref(m_filters), std::ref(mel));
|
| 3089 |
+
}
|
|
|
|
|
|
|
|
|
|
| 3090 |
|
| 3091 |
+
// main thread
|
| 3092 |
+
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel);
|
|
|
|
|
|
|
| 3093 |
|
| 3094 |
+
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
| 3095 |
+
workers[iw].join();
|
| 3096 |
+
}
|
|
|
|
|
|
|
| 3097 |
}
|
|
|
|
|
|
|
|
|
|
| 3098 |
|
| 3099 |
+
// clamping and normalization
|
| 3100 |
+
double mmax = -1e20;
|
| 3101 |
+
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
| 3102 |
+
if (mel.data[i] > mmax) {
|
| 3103 |
+
mmax = mel.data[i];
|
| 3104 |
+
}
|
| 3105 |
}
|
| 3106 |
|
| 3107 |
+
mmax -= 8.0;
|
|
|
|
| 3108 |
|
| 3109 |
+
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
| 3110 |
+
if (mel.data[i] < mmax) {
|
| 3111 |
+
mel.data[i] = mmax;
|
| 3112 |
+
}
|
| 3113 |
|
| 3114 |
+
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3115 |
}
|
| 3116 |
+
|
| 3117 |
+
return mel;
|
| 3118 |
}
|
| 3119 |
+
};
|
| 3120 |
+
}
|
| 3121 |
|
| 3122 |
+
whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) {
|
| 3123 |
+
#if GGML_USE_CUDA
|
| 3124 |
+
if (ggml_backend_is_cuda(backend)) {
|
| 3125 |
+
auto ret = whisper_mel_calc_create_cuda(backend, filters);
|
| 3126 |
+
// run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
|
| 3127 |
+
const float warmup[256] = {0};
|
| 3128 |
+
ret->calculate({warmup, 256}, 1);
|
| 3129 |
+
return ret;
|
| 3130 |
+
} else
|
| 3131 |
+
#endif
|
| 3132 |
+
return new mel_calc_cpu(filters);
|
| 3133 |
}
|
| 3134 |
|
| 3135 |
// split text into tokens
|
|
|
|
| 3589 |
return nullptr;
|
| 3590 |
}
|
| 3591 |
|
| 3592 |
+
ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters);
|
| 3593 |
+
|
| 3594 |
loader->close(loader->context);
|
| 3595 |
|
| 3596 |
return ctx;
|
|
|
|
| 3711 |
|
| 3712 |
ggml_backend_free(ctx->backend);
|
| 3713 |
|
| 3714 |
+
delete ctx->mel_calc;
|
| 3715 |
+
ctx->mel_calc = nullptr;
|
| 3716 |
delete ctx;
|
| 3717 |
}
|
| 3718 |
}
|
|
|
|
| 3730 |
}
|
| 3731 |
|
| 3732 |
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 3733 |
+
const int64_t t_start_us = ggml_time_us();
|
| 3734 |
+
state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads);
|
| 3735 |
+
state->t_mel_us += ggml_time_us() - t_start_us;
|
|
|
|
| 3736 |
|
| 3737 |
+
// Dump log_mel_spectrogram
|
| 3738 |
+
//{
|
| 3739 |
+
// auto& mel = state->mel;
|
| 3740 |
+
// std::ofstream outFile("log_mel_spectrogram.json");
|
| 3741 |
+
// outFile << "[";
|
| 3742 |
+
// for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
|
| 3743 |
+
// outFile << mel.data[i] << ", ";
|
| 3744 |
+
// }
|
| 3745 |
+
// outFile << mel.data[mel.data.size() - 1] << "]";
|
| 3746 |
+
// outFile.close();
|
| 3747 |
+
//}
|
| 3748 |
return 0;
|
| 3749 |
}
|
| 3750 |
|
whisper.h
CHANGED
|
@@ -31,8 +31,10 @@
|
|
| 31 |
|
| 32 |
#define WHISPER_SAMPLE_RATE 16000
|
| 33 |
#define WHISPER_N_FFT 400
|
|
|
|
| 34 |
#define WHISPER_HOP_LENGTH 160
|
| 35 |
#define WHISPER_CHUNK_SIZE 30
|
|
|
|
| 36 |
|
| 37 |
#ifdef __cplusplus
|
| 38 |
extern "C" {
|
|
|
|
| 31 |
|
| 32 |
#define WHISPER_SAMPLE_RATE 16000
|
| 33 |
#define WHISPER_N_FFT 400
|
| 34 |
+
#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1)
|
| 35 |
#define WHISPER_HOP_LENGTH 160
|
| 36 |
#define WHISPER_CHUNK_SIZE 30
|
| 37 |
+
#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)
|
| 38 |
|
| 39 |
#ifdef __cplusplus
|
| 40 |
extern "C" {
|