whisper : Metal and ggml-alloc support (#1270)
* metal : init
* whisper : factor out graph builds
* whisper : allocate encoder and decoder using ggml-alloc
* whisper : ggml-alloc is now supported
* whisper : CoreML support ggml-alloc
* build : fix ggml-alloc
* ios : update submodule
* extra : update sync-ggml.sh script to also sync ggml-alloc
* ci : see if this is causing the crash
* whisper : refactor ggml-alloc init
* whisper.android : try to fix build
* whisper : initial Metal version
* ci : try to debug vmem issue
* metal : decoder works on GPU!
* metal : add multi-decoder support
* ggml : fix ggml_nbytes (probably temp solution)
* metal : run "cross" step on the GPU
* whisper : remove ggml_repeat in the encoder
* whisper : offload the Encoder to Metal
* ggml : use simpler ggml_bytes() implementation
* ggml-alloc : try to make CI happy by reducing vram to 128GB
* whisper : add whisper_allocr to wrap ggml_allocr
* whisper : factor out alloc init in a function
* cmake : update to support Metal build
* whisper : add <functional> header
* objc : fix build (no Metal yet)
* ios : add Metal support
* swiftui : fix build
* metal : speed-up KQ multiplication
* metal : sync latest llama.cpp kernels
* readme : add Metal info
* ios : update submodule
* coreml : add code to toggle Core ML config (CPU, ANE, GPU)
* bench : fix timings by running a pre-heat
* bench : start benching the decoder
* whisper : add ggml_mul_mat_pad
* bench : fix uninitialized vars
* whisper : add comment for disabling mul-mat padding
* whisper : add description of ggml_mul_mat_pad
* whisper : clean-up ggml_mul_mat_pad
* metal : remove the "concurrent" flag
* bench : variable n_past
* ios : update SPM package
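
The ggml-alloc commits above follow a two-pass pattern: each graph is first built against a "measure" allocator, which only tallies offsets and reports the worst-case buffer size, and is then rebuilt against a real allocator backed by a single pre-sized buffer. A minimal sketch of that flow, using the `ggml_allocr_*` API that appears in the ggml-alloc.c diff below (`build_graph` and the alignment value are placeholders standing in for the factored-out encoder/decoder graph builds):

#include <stdlib.h>

#include "ggml.h"
#include "ggml-alloc.h"

// placeholder for one of the factored-out graph builds (encoder, decoder, cross)
struct ggml_cgraph * build_graph(void);

void * alloc_graph_buffer(size_t * buf_size) {
    const size_t alignment = 32; // placeholder; whisper.cpp uses its own tensor alignment

    // pass 1: measure - no tensor data is written, only the required size is computed
    struct ggml_allocr * measure = ggml_allocr_new_measure(alignment);
    *buf_size = ggml_allocr_alloc_graph(measure, build_graph()) + alignment;
    ggml_allocr_free(measure);

    // pass 2: rebuild the same graph inside one buffer of exactly that size
    void * buf = malloc(*buf_size);
    struct ggml_allocr * alloc = ggml_allocr_new(buf, *buf_size, alignment);
    ggml_allocr_alloc_graph(alloc, build_graph());
    ggml_allocr_free(alloc);

    return buf; // the graph's tensor->data pointers now live inside this buffer
}

This is what the "whisper : add whisper_allocr to wrap ggml_allocr" commit packages up per graph (encoder, cross, decoder), so runtime inference needs zero further allocations.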
- CMakeLists.txt +57 -9
- Makefile +22 -1
- README.md +6 -2
- bindings/ios +1 -1
- coreml/whisper-encoder.mm +7 -1
- examples/bench/bench.cpp +40 -3
- examples/talk-llama/CMakeLists.txt +1 -1
- examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt +2 -1
- examples/whisper.objc/README.md +12 -0
- examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj +28 -1
- examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj +8 -2
- extra/bench-all.sh +9 -10
- extra/sync-ggml.sh +15 -13
- ggml-alloc.c +113 -74
- ggml-metal.m +83 -24
- ggml-metal.metal +326 -167
- ggml.c +17 -5
- whisper.cpp +790 -619
CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required (VERSION 3.
+cmake_minimum_required (VERSION 3.5)

 project(whisper.cpp VERSION 1.4.2)

@@ -35,6 +35,12 @@ endif()

 # options

+if (APPLE)
+    set(WHISPER_METAL_DEFAULT ON)
+else()
+    set(WHISPER_METAL_DEFAULT OFF)
+endif()
+
 option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})

 option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
@@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)

 if (APPLE)
     option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
+    option(WHISPER_METAL         "whisper: use Metal"                    ${WHISPER_METAL_DEFAULT})
+    option(WHISPER_METAL_NDEBUG  "whisper: disable Metal debugging"      OFF)
     option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
     option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
 else()
@@ -113,6 +121,34 @@ if (APPLE)
     endif()
 endif()

+if (WHISPER_METAL)
+    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+    find_library(METAL_FRAMEWORK    Metal      REQUIRED)
+    find_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)
+
+    if (METAL_FRAMEWORK)
+        message(STATUS "Metal framework found")
+
+        set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
+            ${FOUNDATION_LIBRARY}
+            ${METAL_FRAMEWORK}
+            ${METALKIT_FRAMEWORK}
+        )
+        set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL)
+
+        if (WHISPER_METAL_NDEBUG)
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
+        endif()
+    else()
+        message(WARNING "Metal framework not found")
+    endif()
+
+    set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
+
+    # copy ggml-metal.metal to bin directory
+    configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
+endif()
+
 if (WHISPER_COREML)
     find_library(FOUNDATION_FRAMEWORK Foundation)
     find_library(COREML_FRAMEWORK CoreML)
@@ -177,7 +213,7 @@ if (WHISPER_CUBLAS)

     enable_language(CUDA)

-    set(
+    set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)

     add_compile_definitions(GGML_USE_CUBLAS)
@@ -228,7 +264,7 @@ if (WHISPER_CLBLAST)
     if (CLBlast_FOUND)
         message(STATUS "CLBlast found")

-        set(
+        set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)

         add_compile_definitions(GGML_USE_CLBLAST)
@@ -426,8 +462,11 @@ set(TARGET whisper)
 add_library(${TARGET}
     ggml.h
     ggml.c
-
-
+    ggml-alloc.h
+    ggml-alloc.c
+    ${GGML_SOURCES_METAL}
+    ${GGML_SOURCES_CUDA}
+    ${GGML_SOURCES_OPENCL}
     whisper.h
     whisper.cpp
     )
@@ -468,9 +507,15 @@ if (BUILD_SHARED_LIBS)
         WHISPER_BUILD
         GGML_BUILD
     )
+
+    if (WHISPER_METAL)
+        # TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory
+        # but for some reason it does not work here like it does in llama.cpp
+        set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    endif()
 endif()

-if (
+if (GGML_SOURCES_CUDA)
     message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
     set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
     set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
@@ -486,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC

 set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")

+include(GNUInstallDirs)
+
 install(TARGETS ${TARGET}
-    LIBRARY
-    ARCHIVE
-    RUNTIME
+    LIBRARY DESTINATION lib
+    ARCHIVE DESTINATION lib/static
+    RUNTIME DESTINATION bin
+    RESOURCE DESTINATION bin
     PUBLIC_HEADER DESTINATION include
     )
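With these CMake changes, Metal is a compile-time feature: `ggml-alloc.c` joins the library unconditionally, while `ggml-metal.m` is only compiled (and `-DGGML_USE_METAL` only defined) when `WHISPER_METAL` is on. In C code this typically gates the GPU path behind the preprocessor; a hedged sketch of such a gate (the wrapper function and the way the Metal context is passed in are illustrative, not the actual whisper.cpp internals):

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#include "ggml.h"

// illustrative dispatch: run a graph on Metal when a context is available,
// otherwise fall back to the CPU scheduler
static void compute_graph(struct ggml_context * ctx0, struct ggml_cgraph * gf,
                          void * ctx_metal, int n_threads) {
#ifdef GGML_USE_METAL
    if (ctx_metal) {
        ggml_metal_graph_compute((struct ggml_metal_context *) ctx_metal, gf);
        return;
    }
#endif
    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
}

Builds without Metal pay nothing for the feature, and builds with it can still fall back to the CPU at runtime if no Metal context was created.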
Makefile
@@ -18,7 +18,7 @@ ifndef NVCC_VERSION
 endif
 endif

-CCV
+CCV := $(shell $(CC) --version | head -n 1)
 CXXV := $(shell $(CXX) --version | head -n 1)

 # Mac OS + Arm can report x86_64
@@ -182,6 +182,15 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK
 endif
 endif

+ifndef WHISPER_NO_METAL
+ifeq ($(UNAME_S),Darwin)
+	WHISPER_METAL := 1
+
+	CXXFLAGS += -DGGML_USE_METAL
+	LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
+endif
+endif
+
 ifdef WHISPER_OPENBLAS
 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
 	LDFLAGS += -lopenblas
@@ -288,6 +297,11 @@ $(info )
 ggml.o: ggml.c ggml.h ggml-cuda.h
	$(CC)  $(CFLAGS)   -c $< -o $@

+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+	$(CC)  $(CFLAGS)   -c $< -o $@
+
+WHISPER_OBJ += ggml-alloc.o
+
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
	$(CXX) $(CXXFLAGS) -c $< -o $@

@@ -303,6 +317,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp
 WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
 endif

+ifdef WHISPER_METAL
+ggml-metal.o: ggml-metal.m ggml-metal.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-metal.o
+endif
+
 libwhisper.a: ggml.o $(WHISPER_OBJ)
	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
README.md
@@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S
 High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:

 - Plain C/C++ implementation without dependencies
-- Apple
+- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
 - AVX intrinsics support for x86 architectures
 - VSX intrinsics support for POWER architectures
 - Mixed F16 / F32 precision
 - [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
 - Low memory usage (Flash Attention)
 - Zero memory allocations at runtime
-
-
+- Support for CPU-only inference
 - [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
 - [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
 - [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
@@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command]

 https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4

+On Apple Silicon, the inference runs fully on the GPU via Metal:
+
+https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
+
 Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)

 ## Implementation details
bindings/ios
@@ -1 +1 @@
-Subproject commit
+Subproject commit 22a9eef021afc67f2154bc9811ed620b26299d1b
coreml/whisper-encoder.mm
@@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {

     NSURL * url_model = [NSURL fileURLWithPath: path_model_str];

-
+    // select which device to run the Core ML model on
+    MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
+    config.computeUnits = MLComputeUnitsCPUAndGPU;
+    //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
+    //config.computeUnits = MLComputeUnitsAll;
+
+    const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);

     if (data == NULL) {
         return NULL;
examples/bench/bench.cpp
@@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n", params.n_threads);
     fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, "  -w N,     --what N        [%-7d] what to benchmark:\n", params.what);
-    fprintf(stderr, "                            %-7s 0 - whisper
+    fprintf(stderr, "                            %-7s 0 - whisper\n", "");
     fprintf(stderr, "                            %-7s 1 - memcpy\n", "");
     fprintf(stderr, "                            %-7s 2 - ggml_mul_mat\n", "");
     fprintf(stderr, "\n");
 }

-int
+int whisper_bench_full(const whisper_params & params) {
     // whisper init

     struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -69,12 +69,49 @@ int whisper_bench_encoder(const whisper_params & params) {
         fprintf(stderr, "error: failed to set mel: %d\n", ret);
         return 3;
     }
+    // heat encoder
+    if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }
+
+    whisper_token tokens[512];
+    memset(tokens, 0, sizeof(tokens));
+
+    // prompt heat
+    if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }
+
+    // text-generation heat
+    if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }

+    whisper_reset_timings(ctx);
+
+    // actual run
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
         fprintf(stderr, "error: failed to encode model: %d\n", ret);
         return 4;
     }

+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+            return 4;
+        }
+    }
+
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+            return 4;
+        }
+    }
+
     whisper_print_timings(ctx);
     whisper_free(ctx);

@@ -103,7 +140,7 @@ int main(int argc, char ** argv) {
     int ret = -1;

     switch (params.what) {
-        case 0: ret =
+        case 0: ret = whisper_bench_full(params);                   break;
         case 1: ret = whisper_bench_memcpy(params.n_threads);       break;
         case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
         default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
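The new `whisper_bench_full` warms up the encoder and both decoder shapes (batched prompt processing and single-token text generation) before calling `whisper_reset_timings`, so one-time costs such as Metal shader compilation and buffer setup are excluded from the reported numbers. A small sketch of the same measure-after-warm-up pattern in isolation, assuming only the public whisper.h API (`bench_decode_ms` is a hypothetical helper, not part of the PR):

#include "whisper.h"
#include "ggml.h" // for ggml_time_us()

// time a single-token decode at a given n_past, after one untimed warm-up call
static double bench_decode_ms(struct whisper_context * ctx,
                              const whisper_token * tokens,
                              int n_past, int n_threads) {
    whisper_decode(ctx, tokens, 1, n_past, n_threads); // warm-up, not timed

    const int64_t t0 = ggml_time_us();
    whisper_decode(ctx, tokens, 1, n_past, n_threads); // timed run
    const int64_t t1 = ggml_time_us();

    return (t1 - t0) / 1000.0;
}

Looping over increasing `n_past`, as the real benchmark does with its 256-iteration loop, also exposes how decode cost grows with context length ("bench : variable n_past" above).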
examples/talk-llama/CMakeLists.txt
@@ -7,7 +7,7 @@ if (WHISPER_SDL2)

     # TODO: this is temporary
     # need to export ggml symbols for MSVC, but too lazy ..
-    add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
+    add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)

     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
     target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt
@@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
 set(
     SOURCE_FILES
     ${WHISPER_LIB_DIR}/ggml.c
+    ${WHISPER_LIB_DIR}/ggml-alloc.c
     ${WHISPER_LIB_DIR}/whisper.cpp
     ${CMAKE_SOURCE_DIR}/jni.c
     )
@@ -20,7 +21,7 @@ function(build_library target_name)
         SHARED
         ${SOURCE_FILES}
     )
-
+
     target_link_libraries(${target_name} ${LOG_LIB} android)

     if (${target_name} STREQUAL "whisper_v8fp16_va")
examples/whisper.objc/README.md
@@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription:

 <img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">

+## Core ML
+
 If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:

 <img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
@@ -35,3 +37,13 @@
 Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) to convert the model.

 In this project, `-O3 -DNDEBUG` was also added to `Other C Flags`; adding flags at the app-project level is not ideal in the real world (it applies to all C/C++ files), so consider splitting the xcodeproj into a workspace in your own project.
+
+## Metal
+
+You can also enable Metal to make the inference run on the GPU of your device. This might or might not be more efficient
+compared to Core ML depending on the model and device that you use.
+
+To enable Metal, just add `-DGGML_USE_METAL` instead of the `-DWHISPER_USE_COREML` flag and you are ready.
+This will make both the Encoder and the Decoder run on the GPU.
+
+If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all!
examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj
@@ -7,6 +7,9 @@
 	objects = {

 /* Begin PBXBuildFile section */
+		1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
+		1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
+		184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
 		18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
 		18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
 		18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
@@ -14,7 +17,7 @@
 		18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
 		18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
 		18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
-		18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML
+		18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
 		18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
 		18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
 		7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
@@ -23,7 +26,24 @@
 		7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
 /* End PBXBuildFile section */

+/* Begin PBXCopyFilesBuildPhase section */
+		184447202AB21B25007D6BFE /* CopyFiles */ = {
+			isa = PBXCopyFilesBuildPhase;
+			buildActionMask = 2147483647;
+			dstPath = "";
+			dstSubfolderSpec = 7;
+			files = (
+				184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */,
+			);
+			runOnlyForDeploymentPostprocessing = 0;
+		};
+/* End PBXCopyFilesBuildPhase section */
+
 /* Begin PBXFileReference section */
+		184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = "<group>"; };
+		184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = "<group>"; };
+		1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = "<group>"; };
+		1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = "<group>"; };
 		18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
 		18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
 		18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
@@ -80,6 +100,10 @@
 		18627C7829052BDF00BD2A04 /* whisper.objc */ = {
 			isa = PBXGroup;
 			children = (
+				1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
+				1844471B2AB21655007D6BFE /* ggml-metal.m */,
+				184447182AB211A2007D6BFE /* ggml-alloc.c */,
+				184447192AB211A2007D6BFE /* ggml-alloc.h */,
 				7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
 				7FE342442A0C3FA20015A058 /* coreml */,
 				18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
@@ -126,6 +150,7 @@
 				18627C7229052BDF00BD2A04 /* Sources */,
 				18627C7329052BDF00BD2A04 /* Frameworks */,
 				18627C7429052BDF00BD2A04 /* Resources */,
+				184447202AB21B25007D6BFE /* CopyFiles */,
 			);
 			buildRules = (
 			);
@@ -194,8 +219,10 @@
 				18627C9629052C5800BD2A04 /* ggml.c in Sources */,
 				18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
 				7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
+				1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
 				18627C8C29052BE000BD2A04 /* main.m in Sources */,
 				18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
+				1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
 				7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
@@ -20,6 +20,7 @@
 		0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
 		0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
 		0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
+		18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
 /* End PBXBuildFile section */

 /* Begin PBXFileReference section */
@@ -41,6 +42,8 @@
 		0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
 		0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
 		0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
+		18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
+		18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
 /* End PBXFileReference section */

 /* Begin PBXFrameworksBuildPhase section */
@@ -124,6 +127,8 @@
 		0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
 			isa = PBXGroup;
 			children = (
+				18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
+				18AED4802AB21F2B009D854F /* ggml-alloc.h */,
 				0AAC5DC929539EB0003032C3 /* ggml.c */,
 				0AAC5DCA29539EB0003032C3 /* ggml.h */,
 				0AAC5DC729539EB0003032C3 /* whisper.cpp */,
@@ -242,6 +247,7 @@
 				0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
 				0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
 				0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
+				18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -369,7 +375,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
-				DEVELOPMENT_TEAM =
+				DEVELOPMENT_TEAM = P8JZH34X63;
 				ENABLE_HARDENED_RUNTIME = YES;
 				ENABLE_PREVIEWS = YES;
 				GENERATE_INFOPLIST_FILE = YES;
@@ -410,7 +416,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
-				DEVELOPMENT_TEAM =
+				DEVELOPMENT_TEAM = P8JZH34X63;
 				ENABLE_HARDENED_RUNTIME = YES;
 				ENABLE_PREVIEWS = YES;
 				GENERATE_INFOPLIST_FILE = YES;
extra/bench-all.sh
@@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi

-printf "|
-printf "|
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"

 for model in "${models[@]}"; do
-    # run once to heat-up the cache
-    ./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null
-
     # actual run
     # store stderr output in a variable in order to parse it later
     output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
     ret=$?

     # parse the output:
+    encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
+    decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+    prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
     system_info=$(echo "$output" | grep "system_info")
     n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')

     # floor to milliseconds
+    #encode_time=${encode_time%.*}
+    #decode_time=${decode_time%.*}
+    #prompt_time=${prompt_time%.*}

     config=""

@@ -87,6 +86,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)

     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> |
+        printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
     fi
 done
extra/sync-ggml.sh
@@ -1,18 +1,20 @@
 #!/bin/bash

-cp -rpv ../ggml/src/ggml.c
-cp -rpv ../ggml/src/ggml-
-cp -rpv ../ggml/src/ggml-cuda.
-cp -rpv ../ggml/src/ggml-
-cp -rpv ../ggml/src/ggml-opencl.
-cp -rpv ../ggml/src/ggml-
-cp -rpv ../ggml/src/ggml-metal.
-cp -rpv ../ggml/src/ggml-metal.
-cp -rpv ../ggml/
-cp -rpv ../ggml/
-cp -rpv ../ggml/
-cp -rpv ../ggml/examples/common
-cp -rpv ../ggml/examples/common
+cp -rpv ../ggml/src/ggml.c                ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c          ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-cuda.h           ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu          ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h         ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.cpp       ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h          ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m          ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal      ./ggml-metal.metal
+cp -rpv ../ggml/include/ggml/ggml.h       ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
+cp -rpv ../ggml/examples/common.h         ./examples/common.h
+cp -rpv ../ggml/examples/common.cpp       ./examples/common.cpp
+cp -rpv ../ggml/examples/common-ggml.h    ./examples/common-ggml.h
+cp -rpv ../ggml/examples/common-ggml.cpp  ./examples/common-ggml.cpp

 cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
 cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
@@ -6,6 +6,26 @@
|
|
| 6 |
#include <stdlib.h>
|
| 7 |
#include <string.h>
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
#define UNUSED(x) (void)(x)
|
| 10 |
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
| 11 |
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
|
@@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
|
|
| 99 |
}
|
| 100 |
#endif
|
| 101 |
|
| 102 |
-
|
| 103 |
-
static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 104 |
return ggml_nbytes(tensor);
|
| 105 |
|
| 106 |
UNUSED(alloc);
|
| 107 |
}
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
size = aligned_offset(NULL, size, alloc->alignment);
|
| 112 |
|
| 113 |
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
|
|
@@ -131,14 +164,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
|
| 131 |
if (best_fit_block == -1) {
|
| 132 |
// the last block is our last resort
|
| 133 |
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
|
|
|
| 134 |
if (block->size >= size) {
|
| 135 |
best_fit_block = alloc->n_free_blocks - 1;
|
| 136 |
-
max_avail = MAX(max_avail, block->size);
|
| 137 |
} else {
|
| 138 |
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
|
| 139 |
__func__, size, max_avail);
|
| 140 |
GGML_ASSERT(!"not enough space in the buffer");
|
| 141 |
-
|
| 142 |
}
|
| 143 |
}
|
| 144 |
struct free_block * block = &alloc->free_blocks[best_fit_block];
|
|
@@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
|
| 173 |
}
|
| 174 |
|
| 175 |
// this is a very naive implementation, but for our case the number of free blocks should be very small
|
| 176 |
-
static void
|
| 177 |
void * ptr = tensor->data;
|
| 178 |
|
| 179 |
-
if (
|
| 180 |
// the tensor was not allocated in this buffer
|
| 181 |
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
|
| 182 |
// the easiest way to deal with this is just to ignore it
|
| 183 |
return;
|
| 184 |
}
|
| 185 |
|
| 186 |
-
size_t size =
|
| 187 |
size = aligned_offset(NULL, size, alloc->alignment);
|
| 188 |
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
|
| 189 |
|
|
@@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
|
|
| 277 |
return alloc;
|
| 278 |
}
|
| 279 |
|
| 280 |
-
//
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
| 286 |
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
*alloc = (struct ggml_allocr){
|
| 289 |
-
/*.data = */
|
| 290 |
-
/*.size = */
|
| 291 |
/*.alignment = */ alignment,
|
| 292 |
/*.n_free_blocks = */ 0,
|
| 293 |
/*.free_blocks = */ {{0}},
|
|
@@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
|
| 307 |
}
|
| 308 |
|
| 309 |
void ggml_allocr_free(struct ggml_allocr * alloc) {
|
|
|
|
|
|
|
|
|
|
| 310 |
free(alloc);
|
| 311 |
}
|
| 312 |
|
|
@@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
|
|
| 316 |
|
| 317 |
//////////// compute graph allocator
|
| 318 |
|
| 319 |
-
static bool ggml_is_view(struct ggml_tensor * t) {
|
| 320 |
-
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
| 321 |
-
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
| 325 |
if (a->type != b->type) {
|
| 326 |
return false;
|
|
@@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
|
|
| 336 |
return true;
|
| 337 |
}
|
| 338 |
|
| 339 |
-
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
| 340 |
-
switch (t->op) {
|
| 341 |
-
case GGML_OP_PERMUTE:
|
| 342 |
-
case GGML_OP_RESHAPE:
|
| 343 |
-
case GGML_OP_TRANSPOSE:
|
| 344 |
-
case GGML_OP_VIEW:
|
| 345 |
-
return t->src[0];
|
| 346 |
-
case GGML_OP_CPY:
|
| 347 |
-
return t->src[1];
|
| 348 |
-
default:
|
| 349 |
-
return NULL;
|
| 350 |
-
}
|
| 351 |
-
}
|
| 352 |
-
|
| 353 |
-
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
| 354 |
-
struct ggml_tensor * parent = t;
|
| 355 |
-
do {
|
| 356 |
-
parent = get_view_parent(parent);
|
| 357 |
-
} while (ggml_is_view(parent));
|
| 358 |
-
return parent;
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
static bool ggml_op_can_inplace(enum ggml_op op) {
|
| 362 |
switch (op) {
|
| 363 |
case GGML_OP_SCALE:
|
|
@@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
|
| 365 |
case GGML_OP_DIAG_MASK_INF:
|
| 366 |
case GGML_OP_ADD:
|
| 367 |
case GGML_OP_ADD1:
|
| 368 |
-
case GGML_OP_ACC:
|
| 369 |
case GGML_OP_SUB:
|
| 370 |
case GGML_OP_MUL:
|
| 371 |
case GGML_OP_DIV:
|
|
@@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
|
| 375 |
case GGML_OP_UNARY:
|
| 376 |
case GGML_OP_ROPE:
|
| 377 |
case GGML_OP_RMS_NORM:
|
| 378 |
-
case GGML_OP_SET:
|
| 379 |
case GGML_OP_SOFT_MAX:
|
| 380 |
case GGML_OP_CONT:
|
| 381 |
-
case GGML_OP_ADD_REL_POS:
|
| 382 |
return true;
|
| 383 |
|
| 384 |
default:
|
|
@@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|
| 390 |
struct hash_node * ht = alloc->hash_table;
|
| 391 |
if (node->data == NULL) {
|
| 392 |
if (ggml_is_view(node)) {
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
case GGML_OP_VIEW:
|
| 396 |
-
memcpy(&offset, node->op_params, sizeof(size_t));
|
| 397 |
-
node->data = (char *) node->src[0]->data + offset;
|
| 398 |
-
break;
|
| 399 |
-
case GGML_OP_PERMUTE:
|
| 400 |
-
case GGML_OP_RESHAPE:
|
| 401 |
-
case GGML_OP_TRANSPOSE:
|
| 402 |
-
node->data = node->src[0]->data;
|
| 403 |
-
break;
|
| 404 |
-
case GGML_OP_CPY:
|
| 405 |
-
node->data = node->src[1]->data;
|
| 406 |
-
break;
|
| 407 |
-
default:
|
| 408 |
-
GGML_ASSERT(!"unknown view op");
|
| 409 |
-
break;
|
| 410 |
-
}
|
| 411 |
} else {
|
| 412 |
// see if we can reuse a parent's buffer (inplace)
|
| 413 |
if (ggml_op_can_inplace(node->op)) {
|
|
@@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|
| 418 |
}
|
| 419 |
|
| 420 |
// if the node's data is external, then we cannot re-use it
|
| 421 |
-
if ((
|
| 422 |
-
(char *) parent->data >= ((char *) alloc->data + alloc->size)) {
|
| 423 |
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
|
| 424 |
continue;
|
| 425 |
}
|
|
@@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|
| 427 |
struct hash_node * p_hn = hash_get(ht, parent);
|
| 428 |
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
|
| 429 |
if (ggml_is_view(parent)) {
|
| 430 |
-
struct ggml_tensor * view_src =
|
| 431 |
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
| 432 |
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
| 433 |
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
|
|
@@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|
| 453 |
}
|
| 454 |
}
|
| 455 |
|
| 456 |
-
static size_t
|
| 457 |
struct ggml_allocr * alloc,
|
| 458 |
struct ggml_cgraph ** graphs, int n_graphs,
|
| 459 |
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
|
|
@@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|
| 469 |
struct ggml_tensor * node = gf->nodes[i];
|
| 470 |
|
| 471 |
if (ggml_is_view(node)) {
|
| 472 |
-
struct ggml_tensor * view_src =
|
| 473 |
hash_get(ht, view_src)->n_views += 1;
|
| 474 |
}
|
| 475 |
|
|
@@ -531,11 +571,10 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|
| 531 |
AT_PRINTF("\n");
|
| 532 |
}
|
| 533 |
|
| 534 |
-
|
| 535 |
// update parents
|
| 536 |
// update immediately if there is no parse_seq
|
| 537 |
// update only at barriers if there is parse_seq
|
| 538 |
-
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
|
| 539 |
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
|
| 540 |
int update_end = alloc->parse_seq_len ? ind : ind + 1;
|
| 541 |
for (int i = update_start; i < update_end; i++) {
|
|
@@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|
| 554 |
|
| 555 |
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
| 556 |
if (ggml_is_view(parent)) {
|
| 557 |
-
struct ggml_tensor * view_src =
|
| 558 |
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
| 559 |
view_src_hn->n_views -= 1;
|
| 560 |
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
|
| 561 |
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
|
| 562 |
-
|
| 563 |
}
|
| 564 |
}
|
| 565 |
else {
|
| 566 |
if (parent->data != node->data) {
|
| 567 |
-
|
| 568 |
}
|
| 569 |
}
|
| 570 |
}
|
|
@@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|
| 581 |
for (int i = 0; outputs[g][i] != NULL; i++) {
|
| 582 |
struct ggml_tensor * output = outputs[g][i];
|
| 583 |
AT_PRINTF("output: %s\n", output->name);
|
| 584 |
-
|
| 585 |
}
|
| 586 |
}
|
| 587 |
}
|
|
@@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|
| 590 |
}
|
| 591 |
|
| 592 |
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
|
| 593 |
-
return
|
| 594 |
}
|
|
|
|
| 6 |
#include <stdlib.h>
|
| 7 |
#include <string.h>
|
| 8 |
|
| 9 |
+
#ifdef __has_include
|
| 10 |
+
#if __has_include(<unistd.h>)
|
| 11 |
+
#include <unistd.h>
|
| 12 |
+
#if defined(_POSIX_MAPPED_FILES)
|
| 13 |
+
#include <sys/types.h>
|
| 14 |
+
#include <sys/mman.h>
|
| 15 |
+
#endif
|
| 16 |
+
#endif
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#if defined(_WIN32)
|
| 20 |
+
#define WIN32_LEAN_AND_MEAN
|
| 21 |
+
#ifndef NOMINMAX
|
| 22 |
+
#define NOMINMAX
|
| 23 |
+
#endif
|
| 24 |
+
#include <windows.h>
|
| 25 |
+
#include <memoryapi.h>
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
|
| 29 |
#define UNUSED(x) (void)(x)
|
| 30 |
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
| 31 |
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
|
|
|
 }
 #endif
 
+static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     return ggml_nbytes(tensor);
 
     UNUSED(alloc);
 }
 
+// check if a tensor is allocated by this buffer
+static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
+    void * ptr = tensor->data;
+    return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
+}
+
+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
 void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+#ifdef GGML_ALLOCATOR_DEBUG
+    GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
+    GGML_ASSERT(tensor->data == NULL);  // avoid allocating tensor which already has memory allocated
+#endif
+    size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
 
     AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
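The call to aligned_offset above rounds the requested size up to the allocator's alignment. A minimal sketch of that rounding, written as a standalone C helper (align_up is a hypothetical name, not the ggml function):

    #include <assert.h>
    #include <stddef.h>

    // Round `size` up to the next multiple of `alignment`, which must be a
    // power of two - the same arithmetic aligned_offset performs above.
    static size_t align_up(size_t size, size_t alignment) {
        assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
        return (size + alignment - 1) & ~(alignment - 1);
    }

    int main(void) {
        assert(align_up(100, 32) == 128);
        assert(align_up(128, 32) == 128);
        return 0;
    }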
     if (best_fit_block == -1) {
         // the last block is our last resort
         struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
+        max_avail = MAX(max_avail, block->size);
         if (block->size >= size) {
             best_fit_block = alloc->n_free_blocks - 1;
         } else {
             fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
                     __func__, size, max_avail);
             GGML_ASSERT(!"not enough space in the buffer");
+            return;
         }
     }
     struct free_block * block = &alloc->free_blocks[best_fit_block];
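The allocator scans its free list for the smallest block that still fits and keeps the last block as a fallback, which is what the hunk above patches. A condensed sketch of that best-fit scan under simplified assumptions (the real allocator also tracks block addresses, splits the chosen block, and merges neighbors on free):

    #include <stddef.h>

    struct free_block { size_t size; };

    // returns the index of the smallest block that fits, or -1 if none does;
    // the last block is skipped here because it serves as the last resort
    static int find_best_fit(const struct free_block * blocks, int n, size_t size) {
        int    best      = -1;
        size_t best_size = (size_t)-1;
        for (int i = 0; i < n - 1; i++) {
            if (blocks[i].size >= size && blocks[i].size < best_size) {
                best      = i;
                best_size = blocks[i].size;
            }
        }
        return best;
    }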
 }
 
 // this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     void * ptr = tensor->data;
 
+    if (ggml_allocr_is_own(alloc, tensor) == false) {
         // the tensor was not allocated in this buffer
         // this can happen because the graph allocator will try to free weights and other tensors from different buffers
         // the easiest way to deal with this is just to ignore it
         return;
     }
 
+    size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
 
     return alloc;
 }
 
+// OS specific functions to allocate and free uncommitted virtual memory
+static void * alloc_vmem(size_t size) {
+#if defined(_WIN32)
+    return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
+#elif defined(_POSIX_MAPPED_FILES)
+    void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
+    if (ptr == MAP_FAILED) {
+        return NULL;
+    }
+    return ptr;
+#else
+    // use a fixed address for other platforms
+    uintptr_t base_addr = (uintptr_t)-size - 0x100;
+    return (void *)base_addr;
+#endif
+}
+
+static void free_vmem(void * base_addr, size_t size) {
+#if defined(_WIN32)
+    VirtualFree(base_addr, 0, MEM_RELEASE);
+    UNUSED(size);
+#elif defined(_POSIX_MAPPED_FILES)
+    munmap(base_addr, size);
+#else
+    // nothing to do
+    UNUSED(base_addr);
+    UNUSED(size);
+#endif
+}
+
+// allocate uncommitted virtual memory to measure the size of the graph
+static void alloc_measure_vmem(void ** base_addr, size_t * size) {
+    // 128GB for 64-bit, 1GB for 32-bit
+    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
+    do {
+        *base_addr = alloc_vmem(*size);
+        if (*base_addr != NULL) {
+            AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
+            return;
+        }
+        // try again with half the size
+        *size /= 2;
+    } while (*size > 0);
+
+    GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
+}
+
+static void free_measure_vmem(void * base_addr, size_t size) {
+    free_vmem(base_addr, size);
+}
 
 struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
     struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
 
+    void * base_addr;
+    size_t size;
+
+    alloc_measure_vmem(&base_addr, &size);
+
     *alloc = (struct ggml_allocr){
+        /*.data          = */ base_addr,
+        /*.size          = */ size,
         /*.alignment     = */ alignment,
         /*.n_free_blocks = */ 0,
         /*.free_blocks   = */ {{0}},

@@ ... @@

 }
 
 void ggml_allocr_free(struct ggml_allocr * alloc) {
+    if (alloc->measure) {
+        free_measure_vmem(alloc->data, alloc->size);
+    }
     free(alloc);
 }
 
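The measure allocator hands out addresses from that reserved, uncommitted region, so a graph can be "allocated" once to discover its peak memory requirement and then allocated for real. A sketch of the intended two-pass usage, assuming a hypothetical build_graph() helper that records the same compute graph on each call (error handling omitted):

    extern struct ggml_cgraph * build_graph(void); // hypothetical helper

    size_t estimate_and_allocate(void) {
        const size_t tensor_alignment = 32;

        // pass 1: walk the graph inside the reserved virtual range to measure
        // the worst-case size - no memory is actually committed
        struct ggml_allocr * measure = ggml_allocr_new_measure(tensor_alignment);
        size_t size = ggml_allocr_alloc_graph(measure, build_graph());
        ggml_allocr_free(measure);

        // pass 2: allocate a real buffer of the measured size and reuse it
        void * buf = malloc(size);
        struct ggml_allocr * alloc = ggml_allocr_new(buf, size, tensor_alignment);
        ggml_allocr_alloc_graph(alloc, build_graph());
        // ... evaluate the graph ...
        ggml_allocr_free(alloc);
        free(buf);

        return size;
    }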
 
 //////////// compute graph allocator
 
 static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
     if (a->type != b->type) {
         return false;

@@ ... @@

     return true;
 }
 
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:

@@ ... @@

         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:

@@ ... @@

         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_CONT:
             return true;
 
         default:
     struct hash_node * ht = alloc->hash_table;
     if (node->data == NULL) {
         if (ggml_is_view(node)) {
+            assert(node->view_src->data != NULL);
+            node->data = (char *)node->view_src->data + node->view_offs;
         } else {
             // see if we can reuse a parent's buffer (inplace)
             if (ggml_op_can_inplace(node->op)) {

@@ ... @@

                 }
 
                 // if the node's data is external, then we cannot re-use it
+                if (ggml_allocr_is_own(alloc, parent) == false) {
                     AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
                     continue;
                 }

@@ ... @@

                 struct hash_node * p_hn = hash_get(ht, parent);
                 if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
                     if (ggml_is_view(parent)) {
+                        struct ggml_tensor * view_src = parent->view_src;
                         struct hash_node * view_src_hn = hash_get(ht, view_src);
                         if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
                             // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
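The hunk above decides when a node may write its result directly over a parent's buffer. A condensed restatement of the conditions as a single predicate (a sketch only; the static helpers named here live in ggml-alloc.c and the real code also inspects view sources):

    #include <stdbool.h>

    // A node may take over a parent's buffer only when the op tolerates
    // in-place output, the parent lives in this allocator's buffer (not an
    // external tensor such as a weight), nobody else still consumes the
    // parent, and the two tensors have identical type and dimensions.
    static bool can_reuse_parent(struct ggml_allocr * alloc,
                                 const struct ggml_tensor * node,
                                 const struct ggml_tensor * parent,
                                 int n_children, int n_views) {
        return ggml_op_can_inplace(node->op)
            && ggml_allocr_is_own(alloc, parent)
            && parent->data != NULL
            && n_children == 1 && n_views == 0
            && ggml_are_same_layout(node, parent);
    }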
     }
 }
 
+static size_t ggml_allocr_alloc_graph_tensors_n(
     struct ggml_allocr * alloc,
     struct ggml_cgraph ** graphs, int n_graphs,
     struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {

@@ ... @@

             struct ggml_tensor * node = gf->nodes[i];
 
             if (ggml_is_view(node)) {
+                struct ggml_tensor * view_src = node->view_src;
                 hash_get(ht, view_src)->n_views += 1;
             }
 

@@ ... @@

             AT_PRINTF("\n");
         }
 
             // update parents
             // update immediately if there is no parse_seq
             // update only at barriers if there is parse_seq
+            if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
                 int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
                 int update_end   = alloc->parse_seq_len ? ind : ind + 1;
                 for (int i = update_start; i < update_end; i++) {

@@ ... @@

                 if (p_hn->n_children == 0 && p_hn->n_views == 0) {
                     if (ggml_is_view(parent)) {
+                        struct ggml_tensor * view_src = parent->view_src;
                         struct hash_node * view_src_hn = hash_get(ht, view_src);
                         view_src_hn->n_views -= 1;
                         AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
                         if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
+                            ggml_allocr_free_tensor(alloc, view_src);
                         }
                     }
                     else {
                         if (parent->data != node->data) {
+                            ggml_allocr_free_tensor(alloc, parent);
                         }
                     }
                 }

@@ ... @@

             for (int i = 0; outputs[g][i] != NULL; i++) {
                 struct ggml_tensor * output = outputs[g][i];
                 AT_PRINTF("output: %s\n", output->name);
+                ggml_allocr_free_tensor(alloc, output);
             }
         }
     }

@@ ... @@

 }
 
 size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
+    return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
 }
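The parse_seq mechanism lets a backend (here, the Metal concurrency planner) supply the order in which nodes actually execute; parent tensors are then released only at barrier entries, so buffers stay alive across a span of nodes that may run concurrently on the GPU. A small illustration with hypothetical data:

    // Indices are node positions in the graph; -1 marks a barrier.
    // Parents are freed only when a barrier is reached, covering everything
    // since the previous barrier.
    static const int parse_seq_example[] = {
        0, 2, 1, -1,   // nodes 0, 2, 1 may run concurrently
        3, 4, -1,      // next concurrent span
    };

    static void walk(const int * seq, int len) {
        int last_barrier_pos = 0;
        for (int ind = 0; ind < len; ind++) {
            if (seq[ind] == -1) {
                // release parents of seq[last_barrier_pos .. ind-1] here
                last_barrier_pos = ind + 1;
            }
        }
    }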
@@ -63,7 +63,10 @@ struct ggml_metal_context {

     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);

@@ -77,6 +80,7 @@ struct ggml_metal_context {

     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -117,14 +121,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
| 117 |
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
| 118 |
metal_printf("%s: allocating\n", __func__);
|
| 119 |
|
| 120 |
-
// Show all the Metal device instances in the system
|
| 121 |
-
NSArray * devices = MTLCopyAllDevices();
|
| 122 |
id <MTLDevice> device;
|
| 123 |
NSString * s;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
for (device in devices) {
|
| 125 |
s = [device name];
|
| 126 |
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
| 127 |
}
|
|
|
|
| 128 |
|
| 129 |
// Pick and show default Metal device
|
| 130 |
device = MTLCreateSystemDefaultDevice();
|
|
@@ -139,14 +146,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 139 |
ctx->n_buffers = 0;
|
| 140 |
ctx->concur_list_len = 0;
|
| 141 |
|
| 142 |
-
ctx->d_queue = dispatch_queue_create("
|
| 143 |
|
| 144 |
-
#
|
| 145 |
-
//
|
| 146 |
{
|
| 147 |
NSError * error = nil;
|
| 148 |
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if (error) {
|
| 151 |
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
| 152 |
return NULL;
|
|
@@ -161,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 161 |
|
| 162 |
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
| 163 |
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
| 164 |
-
NSString * path
|
| 165 |
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
| 166 |
|
| 167 |
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
|
@@ -207,7 +222,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);

@@ -221,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);

@@ -247,13 +266,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

 #undef GGML_METAL_ADD_KERNEL
     }
 
-    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
         metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
         metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
@@ -273,7 +294,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {

     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);

@@ -287,6 +311,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {

     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);

@@ -365,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 

@@ -454,6 +480,7 @@ bool ggml_metal_add_buffer(

         }
     }
 
         metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

@@ -463,6 +490,9 @@ bool ggml_metal_add_buffer(

         } else {
             metal_printf("\n");
         }
     }
 
     return true;

@@ -698,6 +728,7 @@ void ggml_metal_graph_compute(

                 case GGML_OP_ADD:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);

@@ -705,6 +736,7 @@ void ggml_metal_graph_compute(

                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
                             [encoder setComputePipelineState:ctx->pipeline_add_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_add];

@@ -721,6 +753,7 @@ void ggml_metal_graph_compute(

                 case GGML_OP_MUL:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);

@@ -728,6 +761,7 @@ void ggml_metal_graph_compute(

                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
                             [encoder setComputePipelineState:ctx->pipeline_mul_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_mul];

@@ -743,6 +777,8 @@ void ggml_metal_graph_compute(

                     } break;
                 case GGML_OP_SCALE:
                     {
                         const float scale = *(const float *) src1->data;
 
                         [encoder setComputePipelineState:ctx->pipeline_scale];

@@ -750,7 +786,7 @@ void ggml_metal_graph_compute(

                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                        const int64_t n = ggml_nelements(dst);
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ -762,7 +798,7 @@ void ggml_metal_graph_compute(

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                        const int64_t n = ggml_nelements(dst);
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ -782,7 +818,7 @@ void ggml_metal_graph_compute(

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                        const int64_t n = ggml_nelements(dst);
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ -796,13 +832,16 @@ void ggml_metal_graph_compute(

                     {
                         const int nth = 32;
 
-
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;

@@ -810,14 +849,23 @@ void ggml_metal_graph_compute(

                    {
                         const int n_past = ((int32_t *)(dst->op_params))[0];
 
-
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
 
-
                     } break;
                 case GGML_OP_MUL_MAT:
                     {

@@ -830,8 +878,8 @@ void ggml_metal_graph_compute(

                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        if (
-
                             src1t == GGML_TYPE_F32 &&
                             [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                             ne00%32 == 0 &&

@@ -856,14 +904,18 @@ void ggml_metal_graph_compute(

                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                            [encoder setBytes:&
-                            [encoder setBytes:&
-                            [encoder setBytes:&
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                             int nth0 = 32;
                             int nth1 = 1;
 
                             // use custom matrix x vector kernel
                             switch (src0t) {

@@ -873,8 +925,14 @@ void ggml_metal_graph_compute(

                                     nth1 = 1;
                                     if (ne11 * ne12 < 4) {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
                                     } else {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                                     }
                                 } break;
                             case GGML_TYPE_Q4_0:

@@ -995,7 +1053,7 @@ void ggml_metal_graph_compute(

                         else if (src0t == GGML_TYPE_Q6_K) {
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else {
-                            int64_t ny = (ne11 +
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
                     }

@@ -1003,6 +1061,7 @@ void ggml_metal_graph_compute(

                 case GGML_OP_GET_ROWS:
                     {
                         switch (src0->type) {
                             case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                             case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                             case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;

@@ -1018,9 +1077,9 @@ void ggml_metal_graph_compute(

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
 
                         const int64_t n = ggml_nelements(src1);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DECL_KERNEL(get_rows_f32);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);

@@ ... @@

     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     metal_printf("%s: allocating\n", __func__);
 
     id <MTLDevice> device;
     NSString * s;
+
+#if TARGET_OS_OSX
+    // Show all the Metal device instances in the system
+    NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
         metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
     }
+#endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();

@@ ... @@

     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
+#ifdef GGML_SWIFT
+    // load the default.metallib file
     {
         NSError * error = nil;
 
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+        NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+        // Load the metallib file into a Metal library
+        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
         if (error) {
             metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;

@@ ... @@

     //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
     NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+    NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
     metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
     NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+        GGML_METAL_ADD_KERNEL(get_rows_f32);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);

@@ ... @@

         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);

@@ ... @@

 #undef GGML_METAL_ADD_KERNEL
     }
 
     metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+#if TARGET_OS_OSX
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
         metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
+#endif
 
     return ctx;
 }

@@ ... @@

     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DEL_KERNEL(get_rows_f32);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);

@@ ... @@

     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);

@@ ... @@

     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
+        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 

@@ ... @@

         }
     }
 
+#if TARGET_OS_OSX
         metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

@@ ... @@

         } else {
             metal_printf("\n");
         }
+#else
+        metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
     }
 
     return true;
                 case GGML_OP_ADD:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);

@@ ... @@

                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
+                            GGML_ASSERT(ne11 == 1);
                             [encoder setComputePipelineState:ctx->pipeline_add_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_add];

@@ ... @@

                 case GGML_OP_MUL:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);

@@ ... @@

                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
+                            GGML_ASSERT(ne11 == 1);
                             [encoder setComputePipelineState:ctx->pipeline_mul_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_mul];

@@ ... @@

                     } break;
                 case GGML_OP_SCALE:
                     {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                         const float scale = *(const float *) src1->data;
 
                         [encoder setComputePipelineState:ctx->pipeline_scale];

@@ ... @@

                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
+                        const int64_t n = ggml_nelements(dst)/4;
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ ... @@

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
+                        const int64_t n = ggml_nelements(dst)/4;
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ ... @@

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
+                        const int64_t n = ggml_nelements(dst)/4;
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;

@@ ... @@

                     {
                         const int nth = 32;
 
+                        if (ne00%4 == 0) {
+                            [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                     {
                         const int n_past = ((int32_t *)(dst->op_params))[0];
 
+                        if (ne00%8 == 0) {
+                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
 
+                        if (ne00%8 == 0) {
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        }
+                        else {
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        }
                     } break;
                 case GGML_OP_MUL_MAT:
                     {
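Both the soft_max and diag_mask_inf dispatches now follow the same pattern: pick the vectorized kernel (float4 or float4x2) only when the row size divides evenly by the vector width, and fall back to the scalar kernel otherwise. A sketch of that selection logic in C (hypothetical helper names; the real code selects a Metal compute pipeline state rather than a function pointer):

    typedef void (*row_kernel_t)(const float * src, float * dst, int n);

    // The vectorized kernel assumes the row size is a multiple of its width,
    // so it may only be chosen when that divisibility holds.
    static row_kernel_t pick_kernel(int ne00,
                                    row_kernel_t scalar_kernel,
                                    row_kernel_t vec_kernel, int vec_width) {
        return (ne00 % vec_width == 0) ? vec_kernel : scalar_kernel;
    }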
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                        if (!ggml_is_transposed(src0) &&
+                            !ggml_is_transposed(src1) &&
                             src1t == GGML_TYPE_F32 &&
                             [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                             ne00%32 == 0 &&

@@ ... @@

                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
+                            [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:13];
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                             int nth0 = 32;
                             int nth1 = 1;
+                            int nrows = 1;
 
                             // use custom matrix x vector kernel
                             switch (src0t) {

@@ ... @@

                                     nth1 = 1;
                                     if (ne11 * ne12 < 4) {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                    //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                    } else if (false) {
+                                        // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                        nrows = ne11;
                                     } else {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                        nrows = 4;
                                     }
                                 } break;
                             case GGML_TYPE_Q4_0:
                         else if (src0t == GGML_TYPE_Q6_K) {
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else {
+                            int64_t ny = (ne11 + nrows - 1)/nrows;
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
                     }
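In the matrix-vector fallback path each threadgroup now processes nrows rows of src1 (4 for the f16 kernel), so the grid's y dimension shrinks by the usual ceiling division, which is exactly what the ny expression above computes. A tiny self-contained illustration:

    #include <assert.h>

    // Ceiling division: groups of `group` rows must still cover a row count
    // that is not a multiple of `group`.
    static int ceil_div(int n, int group) {
        return (n + group - 1) / group;
    }

    int main(void) {
        assert(ceil_div(10, 4) == 3); // 3 groups of 4 cover 10 rows
        assert(ceil_div( 8, 4) == 2);
        return 0;
    }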
                 case GGML_OP_GET_ROWS:
                     {
                         switch (src0->type) {
+                            case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f32];  break;
                             case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                             case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                             case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;

@@ ... @@

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
 
                         const int64_t n = ggml_nelements(src1);
@@ -38,7 +38,7 @@ kernel void kernel_add_row(

         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }

@@ -63,18 +63,18 @@ kernel void kernel_mul_row(

 }
 
 kernel void kernel_scale(
-        device const
-        device
         constant float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const
-        device
         uint tpig[[thread_position_in_grid]]) {
-
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;

 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-        device const
-        device
         uint tpig[[thread_position_in_grid]]) {
-
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!

@@ -107,7 +107,6 @@ kernel void kernel_soft_max(

     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
-    threadgroup float  * buf [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(

     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-
-    }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
 
     // parallel sum
-
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-
         // Remember the result of exp here. exp is expensive, so we really do not
         // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-
-
-    for (
-
-
-
-
     }
 
-
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
 
-    //
 
-    const float sum =
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-
     }
 }
 
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(

         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
     }
 }
 

@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(

     }
 }
 
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,

@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(

     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float      * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[
 
-    const uint16_t kmask1 =
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/
-    const int ix  = tiisg%
-    const int ip  = tid/
-    const int il  = tid/2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-
-
 
     const int shift = 2*il;
-    const
-    const
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 +
-    const int ik = 4 + (il%2);
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;

@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(

     device const float * y1 = yy + ix*QK_K + y_offset;
 
-
-
     for (int l = 0; l < 8; ++l) {
-        yl[l+0] = y1[l+ 0];
-        yl[l+8] = y1[l+16];
     }
 
     device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);

@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(

     for (int row = 0; row < 2; ++row) {
 
         const float d_all = (float)dh[0];
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-
         for (int l = 0; l < n; l += 2) {
-            const
-            s1 += yl[l+0] * (
-            s2 += yl[l+1] * (
         }
-        float
-
-
 
-        s1 = s2 = 0;
         for (int l = 0; l < n; l += 2) {
-            const
-            s1 += yl[l+8] * (
-            s2 += yl[l+9] * (
         }
-
-
-
 
         q += step;
         h += step;

@@ -1201,15 +1308,17 @@ kernel void kernel_mul_mat_q3_K_f32(

     }
 
-    y1 +=
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row]
-
-
-
     }
 }
 }

@@ -1564,17 +1673,25 @@ kernel void kernel_mul_mat_q5_K_f32(

     sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
     sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-    float4
     for (int l = 0; l < n; ++l) {
         uint8_t h = qh[l];
-
-
-
-
     }
     const float dall = dh[0];
     const float dmin = dh[1];
-    sumf[row] += dall * (
             dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
     q1 += step;

@@ -1747,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(

 
 //============================= templates and their specializations =============================
 
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));

@@ -1758,28 +1884,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)

 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const
-    const
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 =
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]
-        reg[i/2][2*(i%2)+1] = (
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const
-    const
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 =
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]
-        reg[i/2][2*(i%2)+1] = ((
     }
 }
 

@@ -1815,7 +1943,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg

 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;

@@ -1828,16 +1956,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg

         ((il/4)>0 ? 12 : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
-
-
 
-    il = (il/2)
-
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] =
     }
 #else
     float kcoef = il&1 ? 1.f/16.f : 1.f;

@@ -1852,26 +1982,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg

 #endif
 }
 
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il
-    const
-    const
-    const
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {

@@ -1885,19 +2020,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg

     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il
-    const
-    const
-    const
 
-    const ushort mask
-    const
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }

@@ -1916,7 +2051,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg

 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;

@@ -1924,19 +2059,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg

 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-
-    il = (il/2)
 #else
     ql = ql + 16 * (il&1);
-
 #endif
     for (int i = 0; i < 16; ++i) {
-
-
-
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
     }
 }
 

@@ -1976,22 +2113,25 @@ kernel void kernel_get_rows(

 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const uchar * src0,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;

@@ -2004,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){

@@ -2012,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,

     }
 
     short il = (tiitg % THREAD_PER_ROW);
-
-
-
-
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory

@@ -2095,6 +2240,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;

@@ -2105,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows

 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
-typedef void (mat_mm_t)(
-
-
-
-
-
-
-
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
+        constant   int64_t  & nb,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 }
 
 kernel void kernel_scale(
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float  & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 

@@ ... @@

 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 ntg[[threads_per_threadgroup]]) {

@@ ... @@

     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
+    const float max = simd_max(lmax);
 
     // parallel sum
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
+kernel void kernel_soft_max_4(
+    device const float * src0,
+    device       float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
+    }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
+    const float max = simd_max(lmax);
 
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
+    const float sum = simd_sum(lsum);
 
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
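For reference, the computation both soft_max kernels perform per row is the standard numerically stable softmax: subtract the row maximum before exponentiating, cache the exponentials, then normalize. A plain C version of the same three passes (the Metal kernels split each pass across a SIMD group and combine partials with simd_max()/simd_sum()):

    #include <math.h>

    static void soft_max_row(const float * src, float * dst, int n) {
        float max = src[0];
        for (int i = 1; i < n; i++) {
            if (src[i] > max) max = src[i];
        }
        float sum = 0.0f;
        for (int i = 0; i < n; i++) {
            dst[i] = expf(src[i] - max); // cache exp - it is expensive
            sum += dst[i];
        }
        for (int i = 0; i < n; i++) {
            dst[i] /= sum;
        }
    }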
[...]
        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
    } else {
        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+   }
+}
+
+kernel void kernel_diag_mask_inf_8(
+       device const float4 * src0,
+       device       float4 * dst,
+       constant    int64_t & ne00,
+       constant    int64_t & ne01,
+       constant        int & n_past,
+       uint3 tpig[[thread_position_in_grid]]) {
+
+   const int64_t i = 2*tpig[0];
+
+   dst[i+0] = src0[i+0];
+   dst[i+1] = src0[i+1];
+   int64_t i4 = 4*i;
+   const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+   const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+   const int64_t i00 = i4;
+   for (int k = 3; k >= 0; --k) {
+       if (i00 + 4 + k <= n_past + i01) {
+           break;
+       }
+       dst[i+1][k] = -INFINITY;
+       if (i00 + k > n_past + i01) {
+           dst[i][k] = -INFINITY;
+       }
    }
 }
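The new _8 variant copies two float4 values per thread and then patches up only the entries that cross the causal boundary. The underlying rule is simple; a CPU sketch of it, assuming the same row/past-token convention as above (illustrative only):

#include <cmath>

// For attention row i01, every column past (n_past + i01) may not be attended
// to, so it is forced to -INFINITY before the softmax.
void diag_mask_inf_row(float * row, int ne00, int i01, int n_past) {
    for (int i00 = 0; i00 < ne00; ++i00) {
        if (i00 > n_past + i01) {
            row[i00] = -INFINITY;
        }
    }
}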
[...]
    }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+       device const  char * src0,
+       device const  char * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01,
+       constant   int64_t & ne02,
+       constant  uint64_t & nb00,
+       constant  uint64_t & nb01,
+       constant  uint64_t & nb02,
+       constant   int64_t & ne10,
+       constant   int64_t & ne11,
+       constant   int64_t & ne12,
+       constant  uint64_t & nb10,
+       constant  uint64_t & nb11,
+       constant  uint64_t & nb12,
+       constant   int64_t & ne0,
+       constant   int64_t & ne1,
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]]) {
+
+   const int nrows = ne11;
+   const int64_t r0 = tgpig.x;
+   const int64_t im = tgpig.z;
+
+   device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+   for (int r1 = 0; r1 < nrows; ++r1) {
+       device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+       float sumf = 0;
+       for (int i = tiisg; i < ne00/4; i += 32) {
+           for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+       }
+
+       float all_sum = simd_sum(sumf);
+       if (tiisg == 0) {
+           dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+       }
+   }
+}
+
 kernel void kernel_alibi_f32(
        device const float * src0,
        device       float * dst,
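The kernel above relies on the standard simdgroup reduction idiom: each of the 32 lanes accumulates a strided partial dot product and simd_sum() adds the partials. A CPU analogue where the lanes are simulated with a plain loop (a sketch under that analogy, not the Metal semantics themselves):

float dot_product_32lanes(const float * x, const float * y, int n) {
    float partial[32] = {0};
    for (int lane = 0; lane < 32; ++lane) {
        for (int i = lane; i < n; i += 32) {   // each lane handles every 32nd element
            partial[lane] += x[i] * y[i];
        }
    }
    float sum = 0.0f;                          // stands in for simd_sum()
    for (int lane = 0; lane < 32; ++lane) sum += partial[lane];
    return sum;
}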
[...]
    device const block_q3_K * x  = (device const block_q3_K *) src0 + first_row*nb + offset0;
    device const float      * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
+   float yl[32];
 
+   const uint16_t kmask1 = 0x3030;
    const uint16_t kmask2 = 0x0f0f;
 
+   const int tid = tiisg/4;
+   const int ix  = tiisg%4;
+   const int ip  = tid/4;          // 0 or 1
+   const int il  = 2*((tid%4)/2);  // 0 or 2
    const int ir  = tid%2;
    const int n   = 8;
    const int l0  = n*ir;
 
+   // One would think that the Metal compiler would figure out that ip and il can only have
+   // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+   // with these two tables.
+   //
+   // Possible masks for the high bit
+   const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                          {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                          {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                          {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+   // Possible masks for the low 2 bits
+   const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+   const ushort4 hm = mm[2*ip + il/2];
 
    const int shift = 2*il;
+   const float v1 = il == 0 ? 4.f : 64.f;
+   const float v2 = 4.f * v1;
 
    const uint16_t s_shift1 = 4*ip;
+   const uint16_t s_shift2 = s_shift1 + il;
 
    const int q_offset = 32*ip + l0;
    const int y_offset = 128*ip + 32*il + l0;
[...]
    device const float * y1 = yy + ix*QK_K + y_offset;
 
+   uint32_t scales32, aux32;
+   thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+   thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+   float sumf1[2] = {0.f};
+   float sumf2[2] = {0.f};
+   for (int i = ix; i < nb; i += 4) {
 
        for (int l = 0; l < 8; ++l) {
+           yl[l+ 0] = y1[l+ 0];
+           yl[l+ 8] = y1[l+16];
+           yl[l+16] = y1[l+32];
+           yl[l+24] = y1[l+48];
        }
 
        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
[...]
        for (int row = 0; row < 2; ++row) {
 
            const float d_all = (float)dh[0];
 
+           scales16[0] = a[4];
+           scales16[1] = a[5];
+           aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+           scales16[0] = a[il+0];
+           scales16[1] = a[il+1];
+           scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+           float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
            for (int l = 0; l < n; l += 2) {
+               const int32_t qs = q[l/2];
+               s1 += yl[l+0] * (qs & qm[il/2][0]);
+               s2 += yl[l+1] * (qs & qm[il/2][1]);
+               s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+               s4 += yl[l+16] * (qs & qm[il/2][2]);
+               s5 += yl[l+17] * (qs & qm[il/2][3]);
+               s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
            }
+           float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+           float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+           sumf1[row] += d1 * (scales[0] - 32);
+           sumf2[row] += d2 * (scales[2] - 32);
 
+           s1 = s2 = s3 = s4 = s5 = s6 = 0;
            for (int l = 0; l < n; l += 2) {
+               const int32_t qs = q[l/2+8];
+               s1 += yl[l+8] * (qs & qm[il/2][0]);
+               s2 += yl[l+9] * (qs & qm[il/2][1]);
+               s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+               s4 += yl[l+24] * (qs & qm[il/2][2]);
+               s5 += yl[l+25] * (qs & qm[il/2][3]);
+               s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
            }
+           d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+           d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+           sumf1[row] += d1 * (scales[1] - 32);
+           sumf2[row] += d2 * (scales[3] - 32);
 
            q += step;
            h += step;
[...]
        }
 
+       y1 += 4 * QK_K;
    }
 
    for (int row = 0; row < 2; ++row) {
+       const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+       sumf1[row] = simd_sum(sumf);
+   }
+   if (tiisg == 0) {
+       for (int row = 0; row < 2; ++row) {
+           dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
        }
    }
 }
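The point of the qm table in the comment above is that a 16-bit word packs eight 2-bit quants, and masking with 0x0003, 0x0300, 0x000c, 0x0c00, ... isolates one quant in place without a shift; the position factor is absorbed into the scale later (the 1/256 and v1/v2 terms). A tiny stand-alone demo of the masking, with made-up values (illustration only):

#include <cstdint>
#include <cstdio>

int main() {
    const uint16_t qs   = 0xB1E2;                     // eight packed 2-bit quants (demo value)
    const int      qm[] = {0x0003, 0x0300, 0x000c, 0x0c00};
    for (int j = 0; j < 4; ++j) {
        printf("masked %d: 0x%04x\n", j, qs & qm[j]); // quant isolated, still scaled by its bit position
    }
    return 0;
}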
[...]
    sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
    sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
+   float4 acc1 = {0.f};
+   float4 acc2 = {0.f};
    for (int l = 0; l < n; ++l) {
        uint8_t h = qh[l];
+       acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+       acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+       acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+       acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+       acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+       acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+       acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+       acc2[3] += h & hm4 ? yh[l+8] : 0.f;
    }
    const float dall = dh[0];
    const float dmin = dh[1];
+   sumf[row] += dall * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +
+                        sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                        sc8[4] * (acc1[2]      + 16.f*acc2[2]) +
+                        sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
    q1 += step;
[...]
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+   float4x4 temp = *(((device float4x4 *)src));
+   for (int i = 0; i < 16; i++){
+       reg[i/4][i%4] = temp[i/4][i%4];
+   }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    half4x4 temp = *(((device half4x4 *)src));
[...]
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+   const float d1 = il ? (xb->d / 16.h) : xb->d;
+   const float d2 = d1 / 256.f;
+   const float md = -8.h * xb->d;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
+   const ushort mask1 = mask0 << 8;
 
    for (int i=0;i<8;i++) {
+       reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+       reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
    }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+   const float d1 = il ? (xb->d / 16.h) : xb->d;
+   const float d2 = d1 / 256.f;
+   const float  m = xb->m;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
+   const ushort mask1 = mask0 << 8;
 
    for (int i=0;i<8;i++) {
+       reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+       reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
    }
 }
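The q4_0/q4_1 decode above avoids per-nibble shifts: each uint16 holds two bytes of packed 4-bit quants, mask0 keeps one nibble of each byte, mask1 keeps the same nibble of the high byte, and d2 = d1/256 folds away the implicit <<8. A simplified stand-alone CPU demo of the trick, with made-up values (a sketch, not the kernel itself):

#include <cstdint>
#include <cstdio>

int main() {
    const uint16_t qs    = 0x4A27;   // two packed bytes of 4-bit quants (demo value)
    const float    d     = 0.5f;     // block scale (demo value)
    const float    d2    = d / 256.f;
    const float    md    = -8.f * d; // q4_0 values are centered on -8
    const uint16_t mask0 = 0x000F;
    const uint16_t mask1 = mask0 << 8;

    const float v0 = d  * (qs & mask0) + md; // low nibble of the low byte
    const float v1 = d2 * (qs & mask1) + md; // low nibble of the high byte, /256 cancels the <<8
    printf("%f %f\n", v0, v1);
    return 0;
}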
[...]
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+   const half d_all = xb->d;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    device const uint8_t * h = (device const uint8_t *)xb->hmask;
    device const  int8_t * scales = (device const int8_t *)xb->scales;
[...]
                             ((il/4)>0 ? 12  : 3);
    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+   int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                              : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+   half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+   const half ml = 4.h * dl;
 
+   il = (il/2) & 3;
+   const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+   const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+   dl *= coef;
 
    for (int i = 0; i < 16; ++i) {
+       reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
    }
 #else
    float kcoef = il&1 ? 1.f/16.f : 1.f;
[...]
 #endif
 }
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+   return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+   device const uchar * q = xb->qs;
 
 #if QK_K == 256
    short is = (il/4) * 2;
    q = q + (il/4) * 32 + 16 * (il&1);
+   il = il & 3;
+   const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+   const half d   = il < 2 ? xb->d : xb->d / 16.h;
+   const half min = xb->dmin;
+   const half dl = d * sc[0];
+   const half ml = min * sc[1];
 #else
    q = q + 16 * (il&1);
    device const uint8_t * s = xb->scales;
    device const half2 * dh = (device const half2 *)xb->d;
    const float2 d = (float2)dh[0];
    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+   const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
 #endif
    const ushort mask = il<2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
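get_scale_min_k4_just2 unpacks the 6-bit (scale, min) pairs that k-quant superblocks store in 12 bytes: the first four pairs occupy the low 6 bits of a byte each, and the remaining pairs are split across a low nibble plus two bits borrowed from the first bytes' unused high bits. A CPU sketch of the same unpacking (equivalent bit manipulation, not the Metal function):

#include <cstdint>

void get_scale_min_k4(int j, const uint8_t * q, uint8_t & sc, uint8_t & m) {
    if (j < 4) {
        sc = q[j]     & 63;                           // low 6 bits
        m  = q[j + 4] & 63;
    } else {
        sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); // nibble + 2 borrowed high bits
        m  = (q[j + 4] >>  4) | ((q[j - 0] >> 6) << 4);
    }
}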
[...]
    device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
    short is = (il/4) * 2;
    q  = q  + 32 * (il/4) + 16 * (il&1);
    qh = qh + 16 * (il&1);
    uint8_t ul = 1 << (il/2);
+   il = il & 3;
+   const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+   const half d   = il < 2 ? xb->d : xb->d / 16.h;
+   const half min = xb->dmin;
+   const half dl = d * sc[0];
+   const half ml = min * sc[1];
 
+   const ushort mask   = il<2 ? 0x0F : 0xF0;
+   const half   qh_val = il<2 ? 16.h : 256.h;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
    }
[...]
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+   const half d_all = xb->d;
    device const uint8_t * ql = (device const uint8_t *)xb->ql;
    device const uint8_t * qh = (device const uint8_t *)xb->qh;
    device const  int8_t * scales = (device const int8_t *)xb->scales;
[...]
 #if QK_K == 256
    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    qh = qh + 32*(il/8) + 16*(il&1);
+   half sc = scales[(il%2) + 2 * ((il/2))];
+   il = (il/2) & 3;
 #else
    ql = ql + 16 * (il&1);
+   half sc = scales[il];
 #endif
+   const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+   const uint16_t kmask2 = il>1 ? 0xF0              : 0x0F;
+   const half     coef   = il>1 ? 1.f/16.h          : 1.h;
+   const half ml = d_all * sc * 32.h;
+   const half dl = d_all * sc * coef;
    for (int i = 0; i < 16; ++i) {
+       const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                           : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+       reg[i/4][i%4] = dl * q - ml;
    }
 }
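For q6_K the bit layout is: a weight's low 4 bits live in ql and its high 2 bits in qh, and the reassembled 6-bit value is recentered by 32 (which is what the ml term folds into the multiply above). A minimal demo of the reassembly for a single weight, with made-up values (a sketch only):

#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t ql = 0x2B;  // low 4 bits (demo value)
    const uint8_t qh = 0x02;  // high 2 bits (demo value)
    const float   d  = 0.25f; // block scale (demo value)

    const int   q6 = (ql & 0xF) | ((qh & 3) << 4); // reassemble the 6-bit quant
    const float w  = d * (q6 - 32);                // q6_K values are centered on 32
    printf("%f\n", w);
    return 0;
}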
[...]
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & gqa,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+   threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
    const uint r0 = tgpig.y;
[...]
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
+   simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
[...]
    }
 
    short il = (tiitg % THREAD_PER_ROW);
+
+   uint   offset0 = im/gqa*nb02;
+   ushort offset1 = il/nl;
+
+   device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+   device const float   * y = (device const float   *)(src1
+       + nb12 * im
+       + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+       + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        //load data and store to threadgroup memory
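kernel_mul_mm is a classic blocked matrix multiply: tiles of the operands are staged in fast threadgroup memory and multiplied as 8x8 simdgroup matrices. The blocking idea itself is hardware-independent; a plain C++ stand-in with 8x8 tiles (illustration of the structure only, not the GPU code):

constexpr int T = 8; // tile edge, matching simdgroup_half8x8 / simdgroup_float8x8

// Assumes n is a multiple of T and C is zero-initialized.
void matmul_tiled(const float * A, const float * B, float * C, int n) {
    for (int i0 = 0; i0 < n; i0 += T)
    for (int j0 = 0; j0 < n; j0 += T)
    for (int k0 = 0; k0 < n; k0 += T)
        for (int i = i0; i < i0 + T; ++i)
        for (int j = j0; j < j0 + T; ++j) {
            float acc = 0.0f; // one element of the tile accumulator
            for (int k = k0; k < k0 + T; ++k) acc += A[i*n + k] * B[k*n + j];
            C[i*n + j] += acc;
        }
}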
[...]
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
[...]
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
+typedef void (mat_mm_t)(
+   device const  uchar * src0,
+   device const  uchar * src1,
+   device        float * dst,
+   constant    int64_t & ne00,
+   constant    int64_t & ne02,
+   constant    int64_t & nb01,
+   constant    int64_t & nb02,
+   constant    int64_t & ne12,
+   constant    int64_t & nb10,
+   constant    int64_t & nb11,
+   constant    int64_t & nb12,
+   constant    int64_t & ne0,
+   constant    int64_t & ne1,
+   constant       uint & gqa,
+   threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
ggml.c

@@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-   size_t nbytes
-   [... removed lines truncated in this rendering ...]
+   size_t nbytes;
+   size_t blck_size = ggml_blck_size(tensor->type);
+   if (blck_size == 1) {
+       nbytes = ggml_type_size(tensor->type);
+       for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+           nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+       }
+   }
+   else {
+       nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+       for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+           nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+       }
+   }
+
    return nbytes;
 }
 
@@ -18345,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_leafs; i++) {
        struct ggml_tensor * node = cgraph->leafs[i];
 
-       GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
+       GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
               i,
               node->ne[0], node->ne[1],
-              ggml_op_name(node->op)
+              ggml_op_name(node->op),
+              ggml_get_name(node));
    }
 
    for (int i = 0; i < GGML_OP_COUNT; i++) {
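The new ggml_nbytes handles two cases: for block-quantized types the first dimension advances in units of nb[0]/blck_size, and every higher dimension contributes (ne[i]-1)*nb[i], so non-contiguous views are sized correctly. A host-side restatement with hypothetical field names (a sketch, not the ggml struct itself):

#include <cstddef>
#include <cstdint>

constexpr int MAX_DIMS = 4;

struct tensor_view {
    int64_t ne[MAX_DIMS]; // elements per dimension
    size_t  nb[MAX_DIMS]; // stride in bytes per dimension
    size_t  type_size;    // bytes per element (or per block)
    int64_t blck_size;    // elements per block (1 for non-quantized types)
};

size_t view_nbytes(const tensor_view & t) {
    // start from one element (or one row of blocks), then add the span of
    // every higher dimension; this is exact even for strided views
    size_t nbytes = (t.blck_size == 1) ? t.type_size
                                       : t.ne[0]*t.nb[0]/t.blck_size;
    for (int i = (t.blck_size == 1 ? 0 : 1); i < MAX_DIMS; ++i) {
        nbytes += (t.ne[i] - 1)*t.nb[i];
    }
    return nbytes;
}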
whisper.cpp

@@ -3,11 +3,16 @@
 #include "coreml/whisper-encoder.h"
 #endif
 
+#ifdef GGML_USE_METAL
+#  include "ggml-metal.h"
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
 
 #include "ggml.h"
+#include "ggml-alloc.h"
 
 #include <algorithm>
 #include <cassert>
 
@@ -24,6 +29,7 @@
 #include <vector>
 #include <regex>
 #include <random>
+#include <functional>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -115,9 +121,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
-#define WHISPER_USE_SCRATCH
-#define WHISPER_MAX_SCRATCH_BUFFERS 16
-
 //
 // ggml helpers
 //
@@ -133,6 +136,44 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
    ggml_graph_compute(graph, &plan);
 }
 
+   [... 38 added lines are not preserved in this rendering ...]
 
 // available whisper models
 enum e_model {
    MODEL_UNKNOWN,
@@ -247,38 +288,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1ull*1024*1024;
 
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_TINY,     62ull*MB },
-    { MODEL_BASE,     80ull*MB },
-    { MODEL_SMALL,   120ull*MB },
-    { MODEL_MEDIUM,  158ull*MB },
-    { MODEL_LARGE,   198ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-    { MODEL_TINY,     18ull*MB },
-    { MODEL_BASE,     24ull*MB },
-    { MODEL_SMALL,    36ull*MB },
-    { MODEL_MEDIUM,   48ull*MB },
-    { MODEL_LARGE,    60ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-    { MODEL_TINY,      4ull*MB },
-    { MODEL_BASE,      4ull*MB },
-    { MODEL_SMALL,     6ull*MB },
-    { MODEL_MEDIUM,    7ull*MB },
-    { MODEL_LARGE,     9ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-    { MODEL_TINY,      4ull*MB },
-    { MODEL_BASE,      4ull*MB },
-    { MODEL_SMALL,     6ull*MB },
-    { MODEL_MEDIUM,    7ull*MB },
-    { MODEL_LARGE,     9ull*MB },
-};
-
 static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
    { GGML_TYPE_F32,
        {
@@ -345,38 +355,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
    },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
-    { MODEL_TINY,      3ull*MB },
-    { MODEL_BASE,      6ull*MB },
-    { MODEL_SMALL,    16ull*MB },
-    { MODEL_MEDIUM,   43ull*MB },
-    { MODEL_LARGE,    71ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
-    { MODEL_TINY,      9ull*MB },
-    { MODEL_BASE,     18ull*MB },
-    { MODEL_SMALL,    53ull*MB },
-    { MODEL_MEDIUM,  141ull*MB },
-    { MODEL_LARGE,   235ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
-    { MODEL_TINY,     30ull*MB },
-    { MODEL_BASE,     38ull*MB },
-    { MODEL_SMALL,    56ull*MB },
-    { MODEL_MEDIUM,   74ull*MB },
-    { MODEL_LARGE,    94ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,      3ull*MB },
-    { MODEL_BASE,      5ull*MB },
-    { MODEL_SMALL,    10ull*MB },
-    { MODEL_MEDIUM,   18ull*MB },
-    { MODEL_LARGE,    27ull*MB },
-};
-
 struct whisper_mel {
    int n_len;
    int n_len_org;
@@ -657,15 +635,57 @@ struct kv_buf {
    std::vector<uint8_t> v;
 };
 
+   [... added lines are not preserved in this rendering ...]
 
 struct whisper_state {
    int64_t t_sample_us = 0;
    int64_t t_encode_us = 0;
    int64_t t_decode_us = 0;
    int64_t t_mel_us = 0;
 
    int32_t n_sample = 0; // number of tokens sampled
    int32_t n_encode = 0; // number of encoder calls
-   int32_t n_decode = 0; // number of decoder calls
    int32_t n_fail_p = 0; // number of logprob threshold failures
    int32_t n_fail_h = 0; // number of entropy threshold failures
@@ -679,13 +699,20 @@ struct whisper_state {
    // buffer for swapping KV caches between decoders during beam-search
    std::vector<kv_buf> kv_swap_bufs;
 
-   //
-   std::vector<uint8_t>
-   [... further removed and added lines are not preserved in this rendering ...]
 
    // decode output (2-dimensional array: [n_tokens][n_vocab])
    std::vector<float> logits;
@@ -705,6 +732,10 @@ struct whisper_state {
    whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
+   [... added lines are not preserved in this rendering ...]
 
 #ifdef WHISPER_USE_OPENVINO
    whisper_openvino_context * ctx_openvino = nullptr;
 #endif
@@ -717,37 +748,6 @@ struct whisper_state {
 
    // [EXPERIMENTAL] speed-up techniques
    int32_t exp_n_audio_ctx = 0; // 0 - use default
-
-   void use_buf(struct ggml_context * ctx, int i) {
-#if defined(WHISPER_USE_SCRATCH)
-       size_t last_size = 0;
-
-       if (i == -1) {
-           last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-       } else {
-           auto & buf = buf_scratch[i];
-           last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
-       }
-
-       if (buf_last >= 0) {
-           buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-       }
-
-       buf_last = i;
-#else
-       (void) i;
-       (void) ctx;
-#endif
-   }
-
-   size_t get_buf_max_mem(int i) const {
-#if defined(WHISPER_USE_SCRATCH)
-       return buf_max_size[i];
-#else
-       (void) i;
-       return 0;
-#endif
-   }
 };
 
 struct whisper_context {
@@ -794,10 +794,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 
 static bool kv_cache_init(
        const struct whisper_hparams & hparams,
-       const size_t mem_bytes,
        struct whisper_kv_cache & cache,
        ggml_type wtype,
        int n_ctx) {
+   [... added lines are not preserved in this rendering ...]
    cache.buf.resize(mem_bytes);
 
    struct ggml_init_params params = {
@@ -813,12 +820,6 @@ static bool kv_cache_init(
        return false;
    }
 
-   const int n_text_state = hparams.n_text_state;
-   const int n_text_layer = hparams.n_text_layer;
-
-   const int n_mem      = n_text_layer*n_ctx;
-   const int n_elements = n_text_state*n_mem;
-
    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
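As the removed lines show, the KV cache is sized as n_text_layer * n_ctx * n_text_state elements for each of K and V; the change moves this computation rather than altering it. A back-of-the-envelope check, with assumed numbers for the base model (sketch only):

#include <cstdio>

int main() {
    const int n_text_state = 512; // base model (assumed)
    const int n_text_layer = 6;   // base model (assumed)
    const int n_ctx        = 448; // text context

    const long n_mem      = 1L * n_text_layer * n_ctx;
    const long n_elements = 1L * n_text_state * n_mem;

    // K and V tensors, 2 bytes per element for F16
    printf("kv self size = %.2f MB\n", 2.0 * n_elements * 2 / (1024.0 * 1024.0));
    return 0;
}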
@@ -961,22 +962,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
    // print memory requirements
    {
-       //
-       [... removed lines truncated in this rendering ...]
-           MEM_REQ_SCRATCH1.at(model.type) +
-           MEM_REQ_SCRATCH2.at(model.type) +
-           MEM_REQ_SCRATCH3.at(model.type) +
-           scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
-           scale*MEM_REQ_KV_CROSS.at(model.type) +
-           scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
-
-       // this is the memory required by one decoder
-       const size_t mem_required_decoder =
-           scale*MEM_REQ_KV_SELF.at(model.type);
-
-       log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-           mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
    }
 
    // initialize all memory buffers
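The hard-coded memory tables removed above are superseded by ggml-alloc, which measures a graph's worst-case tensor memory once and then reuses a single buffer for every subsequent call. A generic sketch of that measure-then-allocate pattern (my own illustration, not the actual ggml-alloc API):

#include <cstddef>
#include <vector>

struct arena {
    std::vector<unsigned char> data;
    size_t offset  = 0;
    size_t high    = 0;    // high-water mark seen during measurement
    bool   measure = true; // first pass only records sizes

    void * alloc(size_t n) {
        offset += n;
        if (offset > high) high = offset;
        return measure ? nullptr : data.data() + offset - n;
    }
    void reset() { offset = 0; }
    void commit() {        // after the measuring pass, allocate for real
        data.resize(high);
        measure = false;
        reset();
    }
};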
@@ -1485,49 +1473,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
    return true;
 }
 
-// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
-// part of the transformer model and returns the encoded features
-//
-// - wctx:       the model
-// - wstate:     the state of the encoder
-// - n_threads:  number of threads to use
-// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
-//
-static bool whisper_encode_internal(
-       whisper_context & wctx,
-       whisper_state   & wstate,
-       const int  mel_offset,
-       const int  n_threads){
 
+   [... added lines (the factored-out encoder graph build) are not preserved in this rendering ...]
 
    const auto & model   = wctx.model;
    const auto & mel_inp = wstate.mel;
    const auto & hparams = model.hparams;
 
    const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
-   const int n_state = hparams.n_audio_state;
-   const int n_head  = hparams.n_audio_head;
-   const int n_layer = hparams.n_audio_layer;
 
    const int n_mels = hparams.n_mels;
-   assert(mel_inp.n_mel == n_mels);
 
    struct ggml_init_params params = {
-       /*.mem_size   =*/ wstate.
-       /*.mem_buffer =*/ wstate.
-       /*.no_alloc   =*/
    };
 
    struct ggml_context * ctx0 = ggml_init(params);
 
    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
    assert(mel->type == GGML_TYPE_F32);
-   {
    float * dst = (float *) mel->data;
    memset(dst, 0, ggml_nbytes(mel));
 
@@ -1541,25 +1536,11 @@ static bool whisper_encode_internal(
        }
    }
 
-   struct ggml_tensor * cur;
-
-#ifndef WHISPER_USE_COREML
-   const bool use_coreml = false;
-#else
-   const bool use_coreml = wstate.ctx_coreml != nullptr;
-#endif
-
-#ifndef WHISPER_USE_OPENVINO
-   const bool use_openvino = false;
-#else
-   const bool use_openvino = wstate.ctx_openvino != nullptr;
-#endif
 
-   if (!
    // convolution + gelu
    {
-       wstate.use_buf(ctx0, 1);
-
        cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
        cur = ggml_add(ctx0,
            ggml_repeat(ctx0,
 
@@ -1569,8 +1550,6 @@ static bool whisper_encode_internal(
 
        cur = ggml_gelu(ctx0, cur);
 
-       wstate.use_buf(ctx0, 0);
-
        cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
        cur = ggml_add(ctx0,
            ggml_repeat(ctx0,
 
@@ -1581,371 +1560,431 @@ static bool whisper_encode_internal(
        cur = ggml_gelu(ctx0, cur);
    }
 
Most of this hunk removes the scratch-buffer plumbing from the encoder; the removed lines are heavily truncated in this rendering. The recoverable fragments, in order:
 
-   // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
-   //}
 
-   const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
-   const auto & layer = model.layers_encoder[il];
 
-   {
-       wstate.use_buf(ctx0, 0);
-       [...]
-           ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-           cur),
-       ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
-   }
 
-   {
-       wstate.use_buf(ctx0, 1);
-       [...]
-           ggml_repeat(ctx0,
-               layer.attn_q_b,
-               Qcur),
-           Qcur);
 
-       struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-           layer.attn_k_w,
-           cur);
 
-       [...]
-           layer.attn_v_w,
-           cur);
 
 #ifdef WHISPER_USE_FLASH_ATTN
-       [...]
 #else
-       [...]
-           ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
-       );
-
-       struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 #endif
 
-       wstate.use_buf(ctx0, 0);
-       [...]
        cur = ggml_add(ctx0,
-       [...]
    }
 
-   cur = ggml_add(ctx0, cur, inpL);
 
-   wstate.use_buf(ctx0, 0);
 
-   cur = ggml_add(ctx0,
-       ggml_mul(ctx0,
-           ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-           cur),
-       ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-   }
 
-   wstate.use_buf(ctx0, 0);
 
-   [...]
-       ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-       layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
-   wstate.use_buf(ctx0, 0);
 
-   [...]
-       layer.mlp_0_w,
-       cur);
 
-   cur = ggml_gelu(ctx0, cur);
-   [...]
-#endif
-   }
 
-   // norm
-   {
-       wstate.use_buf(ctx0, 0);
-       [...]
-           ggml_repeat(ctx0, model.e_ln_w, cur),
-           cur),
-       ggml_repeat(ctx0, model.e_ln_b, cur));
-   }
 
-   {
-       struct ggml_cgraph gf = {};
-       [...]
-       ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
    }
-#ifdef WHISPER_USE_COREML
-   else if (use_coreml) {
-       wstate.use_buf(ctx0, -1);
-       [...]
-   }
-#endif
-#ifdef WHISPER_USE_OPENVINO
-   else if (use_openvino) {
-       wstate.use_buf(ctx0, -1);
-       [...]
-   }
 #endif
 
-   //{
-   //    printf("ne0 = %d\n", cur->ne[0]);
-   //    printf("ne1 = %d\n", cur->ne[1]);
-   //    for (int i = 0; i < 10; ++i) {
-   //        printf("%8.4f ", ((float *)(cur->data))[i]);
-   //    }
-   //    printf("... ");
-   //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
-   //        printf("%8.4f ", ((float *)(cur->data))[i]);
-   //    }
-   //    printf("\n");
-   //}
 
-   // pre-compute cross-attention memory
    {
-       // TODO: hack to disconnect the encoded features from the previous graph
-       cur->op = GGML_OP_NONE;
-       cur->src[0] = nullptr;
-       cur->src[1] = nullptr;
 
-       for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-           auto & layer = model.layers_decoder[il];
 
-           [...]
-               layer.cross_attn_k_w,
-               cur);
 
-           Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
 
-           wstate.use_buf(ctx0, 1);
 
-           struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
-               layer.cross_attn_v_w,
-               cur);
 
-           Vcross = ggml_add(ctx0,
-               ggml_repeat(ctx0,
-                   layer.cross_attn_v_b,
-                   Vcross),
-               Vcross);
 
-           [...]
-           ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
-           ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
    }
-
-       ggml_graph_compute_helper(wstate.
    }
 
-   //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-   //        ggml_used_mem(ctx0)/1024.0/1024.0,
-   //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
-   //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
-   //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
-   //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
-
-   ggml_free(ctx0);
 
    wstate.t_encode_us += ggml_time_us() - t_start_us;
    wstate.n_encode++;
@@ -1953,26 +1992,13 @@ static bool whisper_encode_internal(
    return true;
 }
 
-// [...]
-// - n_tokens: number of tokens in the prompt
-// - n_past:   number of past tokens to prefix the prompt with
-//
-static bool whisper_decode_internal(
-       whisper_context & wctx,
-       whisper_state   & wstate,
-       whisper_decoder & decoder,
-       const whisper_token * tokens,
-       const int   n_tokens,
-       const int   n_past,
-       const int   n_threads) {
-   const int64_t t_start_us = ggml_time_us();
-
    const auto & model   = wctx.model;
    const auto & hparams = model.hparams;
 
@@ -1980,10 +2006,6 @@ static bool whisper_decode_internal(
 
    WHISPER_ASSERT(!!kv_self.ctx);
 
-   auto & logits_out = wstate.logits;
-
-   const int n_vocab = hparams.n_vocab;
-
    const int n_ctx   = hparams.n_text_ctx;
    const int n_state = hparams.n_text_state;
    const int n_head  = hparams.n_text_head;
 
@@ -1995,24 +2017,39 @@ static bool whisper_decode_internal(
    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
    struct ggml_init_params params = {
-       /*.mem_size   =*/ wstate.
-       /*.mem_buffer =*/ wstate.
-       /*.no_alloc   =*/
    };
 
    struct ggml_context * ctx0 = ggml_init(params);
 
    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 
    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-   [...]
    }
 
    // token encoding + position encoding
    struct ggml_tensor * cur =
 
@@ -2027,16 +2064,14 @@ static bool whisper_decode_internal(
 
    // norm
    {
-       wstate.use_buf(ctx0, 0);
-
        cur = ggml_norm(ctx0, inpL, hparams.eps);
 
        // cur = ln_0_w*cur + ln_0_b
        cur = ggml_add(ctx0,
            ggml_mul(ctx0,
-       [...]
    }
 
    // self-attention
 
@@ -2046,19 +2081,17 @@ static bool whisper_decode_internal(
            cur);
 
        Qcur = ggml_add(ctx0,
-           [...]
-               layer.attn_q_b
-               Qcur),
-           Qcur);
 
-       Qcur =
 
        // note: no bias for Key
        struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
            layer.attn_k_w,
            cur);
 
-       Kcur =
 
        // store key and value to memory
        {
 
@@ -2067,10 +2100,8 @@ static bool whisper_decode_internal(
            cur);
 
        Vcur = ggml_add(ctx0,
-           [...]
-               layer.attn_v_b
-               Vcur),
-           Vcur);
 
        Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
 
@@ -2079,42 +2110,32 @@ static bool whisper_decode_internal(
                (   n_ctx)*ggml_element_size(kv_self.v),
                (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
 
-           ggml_build_forward_expand(
-           ggml_build_forward_expand(
        }
 
        // ------
 
-       wstate.use_buf(ctx0, 0);
-
        struct ggml_tensor * Q =
            ggml_permute(ctx0,
-               [...]
-                   Qcur,
-                   ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                0, 2, 1, 3);
 
        struct ggml_tensor * K =
-           [...]
-
-       wstate.use_buf(ctx0, 1);
 
        // K * Q
        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-       //struct ggml_tensor * KQ_scaled =
-       //    ggml_scale_inplace(ctx0,
-       //            KQ,
-       //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-       //            );
 
-       struct ggml_tensor * KQ_masked =
 
-       struct ggml_tensor * KQ_soft_max =
 
        struct ggml_tensor * V =
            ggml_view_3d(ctx0, kv_self.v,
 
@@ -2134,36 +2155,28 @@ static bool whisper_decode_internal(
 
        // projection
        {
-           wstate.use_buf(ctx0, 0);
-
            cur = ggml_mul_mat(ctx0,
                layer.attn_ln_1_w,
                cur);
 
-           wstate.use_buf(ctx0, 1);
-
            cur = ggml_add(ctx0,
-           [...]
        }
 
-       wstate.use_buf(ctx0, 2);
-
        // add the input
        struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
        // norm
        {
-           wstate.use_buf(ctx0, 0);
-
            cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
 
            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctx0,
                ggml_mul(ctx0,
-           [...]
        }
 
        // cross-attention
 
@@ -2173,18 +2186,18 @@ static bool whisper_decode_internal(
            cur);
 
        Qcur = ggml_add(ctx0,
-           [...]
-               layer.cross_attn_q_b
-               Qcur),
-           Qcur);
 
-       Qcur =
 
        // Kcross is already scaled
        struct ggml_tensor * Kcross =
-           [...]
-           n_state
 
        //struct ggml_tensor * Vcross =
        //    ggml_reshape_3d(ctx0,
 
@@ -2207,26 +2220,22 @@ static bool whisper_decode_internal(
 
        struct ggml_tensor * Q =
            ggml_permute(ctx0,
-               [...]
-                   Qcur,
-                   ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                0, 2, 1, 3);
 
-       struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
        // K * Q
-       struct ggml_tensor * KQ = ggml_mul_mat(ctx0,
 
        //struct ggml_tensor * KQ_scaled =
-       //    [...]
        //            KQ,
        //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
        //            );
 
        // no masking for cross-attention
-       //struct ggml_tensor * KQ_masked =
 
-       struct ggml_tensor * KQ_soft_max =
 
        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2240,21 +2249,15 @@ static bool whisper_decode_internal(
 
        // projection
        {
-           wstate.use_buf(ctx0, 0);
-
            cur = ggml_mul_mat(ctx0,
                layer.cross_attn_ln_1_w,
                cur);
 
-           wstate.use_buf(ctx0, 1);
-
            cur = ggml_add(ctx0,
-           [...]
        }
 
-       wstate.use_buf(ctx0, 2);
-
        // add the input
        cur = ggml_add(ctx0, cur, inpCA);
 
@@ -2264,54 +2267,38 @@ static bool whisper_decode_internal(
    {
        // norm
        {
-           wstate.use_buf(ctx0, 0);
-
            cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
-           wstate.use_buf(ctx0, 1);
-
            // cur = mlp_ln_w*cur + mlp_ln_b
            cur = ggml_add(ctx0,
                ggml_mul(ctx0,
-           [...]
        }
 
-       wstate.use_buf(ctx0, 0);
-
        // fully connected
        cur = ggml_mul_mat(ctx0,
            layer.mlp_0_w,
            cur);
 
-       wstate.use_buf(ctx0, 1);
-
        cur = ggml_add(ctx0,
-       [...]
-
-       wstate.use_buf(ctx0, 0);
 
        // GELU activation
        cur = ggml_gelu(ctx0, cur);
 
-       wstate.use_buf(ctx0, 1);
-
        // projection
        cur = ggml_mul_mat(ctx0,
            layer.mlp_1_w,
            cur);
 
-       wstate.use_buf(ctx0, 0);
-
        cur = ggml_add(ctx0,
-       [...]
    }
 
-   wstate.use_buf(ctx0, 3);
-
    inpL = ggml_add(ctx0, cur, inpFF);
    }
 
@@ -2319,21 +2306,15 @@ static bool whisper_decode_internal(
 
    // norm
    {
-       wstate.use_buf(ctx0, 0);
-
        cur = ggml_norm(ctx0, cur, hparams.eps);
 
-       wstate.use_buf(ctx0, 1);
-
        cur = ggml_add(ctx0,
            ggml_mul(ctx0,
-       [...]
    }
 
-   wstate.use_buf(ctx0, 0);
-
    // compute logits only for the last token
    // comment this line to compute logits for all N tokens
    // might be useful in the future
 
@@ -2341,23 +2322,75 @@ static bool whisper_decode_internal(
 
    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
+   [... added lines are not preserved in this rendering ...]
 
    {
-       [...]
    }
 
    // extract logits for all N tokens
-   //logits_out.resize(
-   //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*
 
    // extract logits only for the last token
    logits_out.resize(n_vocab);
    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
 
-   if (
    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
    //        ggml_used_mem(ctx0)/1024.0/1024.0,
    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
 
@@ -2366,14 +2399,18 @@ static bool whisper_decode_internal(
    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
    }
 
+   [... added lines are not preserved in this rendering ...]
 
    return true;
 }
 
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
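The two comments above give the contract for to_timestamp: t counts 10 ms ticks, so 500 becomes "00:05.000" and 6000 becomes "01:00.000". A minimal sketch of that conversion (the real function also handles hours and a comma separator for SRT output):

#include <cstdint>
#include <cstdio>
#include <string>

std::string to_timestamp_sketch(int64_t t) {
    int64_t msec = t * 10; // t is in units of 10 ms
    const int64_t min = msec / (1000 * 60); msec -= min * 1000 * 60;
    const int64_t sec = msec / 1000;        msec -= sec * 1000;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
    return buf; // 500 -> "00:05.000", 6000 -> "01:00.000"
}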
@@ -2782,9 +2819,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    fill_sin_cos_table();
    whisper_state * state = new whisper_state;
 
-   if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+   [... replacement lines are not preserved in this rendering ...]
        log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
        delete state;
        return nullptr;
 
@@ -2795,7 +2830,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
        log("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
    }
 
-   if (!kv_cache_init(ctx->model.hparams,
+   [... replacement line is not preserved in this rendering ...]
        log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
        delete state;
        return nullptr;
 
@@ -2816,6 +2851,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    if (!state->ctx_coreml) {
        log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
+       [... added line is not preserved in this rendering ...]
        return nullptr;
 #endif
    } else {
 
@@ -2830,15 +2866,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // TAGS: WHISPER_DECODER_INIT
    state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-   state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
-   state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
-   state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
 
+   [... added lines are not preserved in this rendering ...]
 
    state->rng = std::mt19937(0);
@@ -2895,7 +3027,6 @@ int whisper_ctx_init_openvino_encoder(
 }
 
 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-
    log("%s: loading model from '%s'\n", __func__, path_model);
 
    auto fin = std::ifstream(path_model, std::ios::binary);
 
@@ -3048,6 +3179,13 @@ void whisper_free_state(struct whisper_state * state)
    }
 #endif
 
+   [... added lines are not preserved in this rendering ...]
 
 #ifdef WHISPER_USE_OPENVINO
    if (state->ctx_openvino != nullptr) {
        whisper_openvino_free(state->ctx_openvino);
 
@@ -3055,6 +3193,11 @@ void whisper_free_state(struct whisper_state * state)
    }
 #endif
 
+   [... added lines are not preserved in this rendering ...]
 
        delete state;
    }
 }
@@ -3475,12 +3618,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);

         log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
         log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
         log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
     }
     log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
@@ -3490,6 +3635,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
     }
 }

@@ -4339,6 +4489,21 @@ int whisper_full_with_state(
             decoder.probs.resize (ctx->vocab.n_vocab);
             decoder.logits.resize (ctx->vocab.n_vocab);
             decoder.logprobs.resize(ctx->vocab.n_vocab);
         }
     }

@@ -4531,8 +4696,8 @@ int whisper_full_with_state(

         decoder.kv_self.n += prompt.size();

-        memcpy(decoder.probs.data(),
-        memcpy(decoder.logits.data(),
         memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
     }

@@ -5045,6 +5210,12 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;

         whisper_free_state(states[i]);
     }
@@ -5241,8 +5412,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     // b: N*N*sizeof(float)
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<uint8_t> buf
-    std::vector<uint8_t> work

     // put a bunch of random data in the buffer
     for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
...
 #include "coreml/whisper-encoder.h"
 #endif

+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif

 #include "ggml.h"
+#include "ggml-alloc.h"

 #include <algorithm>
 #include <cassert>
...
 #include <vector>
 #include <regex>
 #include <random>
+#include <functional>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
...
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16

 //
 // ggml helpers
 //
...
     ggml_graph_compute(graph, &plan);
 }

+// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
+// the idea is to represent the original matrix multiplication:
+//
+//   Z = X @ Y
+//
+// with the sum of two matrix multiplications:
+//
+//   Z = (X_0 @ Y_0) + (X_1 @ Y_1)
+//
+// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
+// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
+// general-purpose kernels
+//
+static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
+    // use padding only if dimension 0 is at least 8 times larger than the padding
+    // else we won't get much benefit from the optimization
+    const int n_pad_req = 8;
+
+    if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
+        return ggml_mul_mat(ctx, x, y);
+    }
+
+    struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
+    struct ggml_tensor * x_1 = ggml_view_3d(ctx, x,  x->ne[0]%pad,      x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
+
+    struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
+    struct ggml_tensor * y_1 = ggml_view_3d(ctx, y,  y->ne[0]%pad,      y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
+
+    return ggml_add(ctx,
+            ggml_mul_mat(ctx, x_0, y_0),
+            ggml_mul_mat(ctx, x_1, y_1));
+}
+
+// TODO: check if other platforms can benefit from this optimization
+#if defined(GGML_USE_METAL)
+#define ggml_mul_mat ggml_mul_mat_pad
+#endif
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
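Worth spelling out what the override at the end does: under Metal, every ggml_mul_mat in the graph builders below silently goes through ggml_mul_mat_pad, and the split never changes the math — it only partitions dimension 0 into a padded slab plus a small remainder. A standalone sketch of the partition arithmetic (not part of whisper.cpp; pad and n_pad_req match the defaults above, ne0 = 1500 is purely illustrative):

    // Sketch of the view split performed by ggml_mul_mat_pad above.
    #include <cstdio>

    int main() {
        const int pad       = 32; // same default as ggml_mul_mat_pad
        const int n_pad_req = 8;  // minimum ne0/pad ratio for the split to pay off

        const int ne0 = 1500;     // example size of dimension 0

        if (ne0 % pad == 0 || ne0 / pad < n_pad_req) {
            printf("no split: single ggml_mul_mat over %d elements\n", ne0);
        } else {
            const int ne0_main = (ne0 / pad) * pad; // 1472 -> padded fast kernel
            const int ne0_rem  =  ne0 % pad;        //   28 -> general-purpose kernel
            printf("split: %d = %d + %d\n", ne0, ne0_main, ne0_rem);
        }
        return 0;
    }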
...

 static const size_t MB = 1ull*1024*1024;

+// TODO: avoid using GGUF
 static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
     { GGML_TYPE_F32,
         {
...
     },
 };

...
 struct whisper_mel {
     int n_len;
     int n_len_org;
...
     std::vector<uint8_t> v;
 };

+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+    std::vector<uint8_t> data;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + allocr.data.size();
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
+    const int tensor_alignment = 32;
+
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+    auto & data  = allocr.data;
+
+    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+
+    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+
+    ggml_allocr_free(alloc);
+
+    data.resize(alloc_size);
+
+    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        allocr.alloc = nullptr;
+    }
+}
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;

     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
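whisper_allocr_graph_init is the heart of the ggml-alloc integration: each graph is built once against a measuring allocator to learn its worst-case tensor footprint, and the real data buffer is then sized from that measurement and reused for every evaluation. A minimal self-contained sketch of the same three steps, assuming the ggml.h / ggml-alloc.h that ship with this commit (the 64x64 graph is a stand-in):

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <vector>

    int main() {
        const int tensor_alignment = 32;

        // meta buffer: holds only tensor/graph metadata (no_alloc == true)
        std::vector<uint8_t> meta(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());

        struct ggml_init_params params = {
            /*.mem_size   =*/ meta.size(),
            /*.mem_buffer =*/ meta.data(),
            /*.no_alloc   =*/ true,
        };

        // 1) measure: build the graph once with a measuring allocator
        ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);

        struct ggml_context * ctx = ggml_init(params);
        struct ggml_cgraph  * gf  = ggml_new_graph(ctx);

        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        ggml_allocr_alloc(alloc, a);
        ggml_allocr_alloc(alloc, b);

        ggml_build_forward_expand(gf, ggml_mul_mat(ctx, a, b));

        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
        ggml_allocr_free(alloc);
        ggml_free(ctx);

        // 2) allocate: size the data buffer to the measured worst case
        std::vector<uint8_t> data(alloc_size);

        // 3) reuse: this allocator is reset and reused for each subsequent graph
        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
        ggml_allocr_free(alloc);

        return 0;
    }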
...
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;

+    // reusable buffer for `struct ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
+    // ggml-alloc:
+    // - stores meta info about the intermediate tensors into the `meta` buffers
+    // - stores the actual tensor data into the `data` buffers
+    whisper_allocr alloc_conv;
+    whisper_allocr alloc_encode;
+    whisper_allocr alloc_cross;
+    whisper_allocr alloc_decode;

+    // result of the encoder
+    struct ggml_tensor * embd_conv = nullptr;
+    struct ggml_tensor * embd_enc  = nullptr;

     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
...
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif

+#ifdef GGML_USE_METAL
+    ggml_metal_context * ctx_metal = nullptr;
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
...
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
 };

 struct whisper_context {
...
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
               struct whisper_kv_cache & cache,
                              ggml_type   wtype,
                                    int   n_ctx) {
+    const int64_t n_text_state = hparams.n_text_state;
+    const int64_t n_text_layer = hparams.n_text_layer;
+
+    const int64_t n_mem      = n_text_layer*n_ctx;
+    const int64_t n_elements = n_text_state*n_mem;
+
+    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
+
     cache.buf.resize(mem_bytes);

     struct ggml_init_params params = {
...
         return false;
     }

     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
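The cache is now sized from the hyperparameters instead of the old per-model lookup tables. A hedged back-of-the-envelope check of the computation above, taking the base model's text hyperparameters (n_text_state = 512, n_text_layer = 6, n_text_ctx = 448) and 2-byte F16 elements, and dropping the small ggml_tensor_overhead() term:

    // Illustrative sizing of the self-attention KV cache (base model, F16).
    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t n_text_state = 512;
        const int64_t n_text_layer = 6;
        const int64_t n_ctx        = 448; // n_text_ctx

        const int64_t n_mem      = n_text_layer*n_ctx; // 2688 cached positions
        const int64_t n_elements = n_text_state*n_mem; // 1376256 elements

        // factor 2: one tensor each for K and V; 2 bytes per F16 element
        const int64_t mem_bytes = 2*(2*n_elements);

        printf("kv self size = %.2f MB\n", mem_bytes/1024.0/1024.0); // ~5.25 MB
        return 0;
    }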
...

     // print memory requirements
     {
+        // TODO
+        //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+        //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
     }

     // initialize all memory buffers
...
     return true;
 }

+static bool whisper_encode_external(const whisper_state & wstate) {
+    GGML_UNUSED(wstate);
+
+#ifndef WHISPER_USE_COREML
+    const bool use_coreml = false;
+#else
+    const bool use_coreml = wstate.ctx_coreml != nullptr;
+#endif
+
+#ifndef WHISPER_USE_OPENVINO
+    const bool use_openvino = false;
+#else
+    const bool use_openvino = wstate.ctx_openvino != nullptr;
+#endif
+
+    return use_coreml || use_openvino;
+}

+
static struct ggml_cgraph * whisper_build_graph_conv(
|
| 1495 |
+
whisper_context & wctx,
|
| 1496 |
+
whisper_state & wstate,
|
| 1497 |
+
const int mel_offset) {
|
| 1498 |
const auto & model = wctx.model;
|
| 1499 |
const auto & mel_inp = wstate.mel;
|
| 1500 |
const auto & hparams = model.hparams;
|
| 1501 |
|
| 1502 |
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 1503 |
+
const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);
|
|
|
|
|
|
|
| 1504 |
|
| 1505 |
const int n_mels = hparams.n_mels;
|
|
|
|
| 1506 |
|
| 1507 |
struct ggml_init_params params = {
|
| 1508 |
+
/*.mem_size =*/ wstate.alloc_conv.meta.size(),
|
| 1509 |
+
/*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
|
| 1510 |
+
/*.no_alloc =*/ true,
|
| 1511 |
};
|
| 1512 |
|
| 1513 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1514 |
|
| 1515 |
+
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
| 1516 |
+
|
| 1517 |
+
ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
| 1518 |
|
| 1519 |
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
| 1520 |
+
ggml_allocr_alloc(alloc, mel);
|
| 1521 |
+
|
| 1522 |
assert(mel->type == GGML_TYPE_F32);
|
| 1523 |
+
if (!ggml_allocr_is_measure(alloc)) {
|
| 1524 |
+
assert(mel_inp.n_mel == n_mels);
|
| 1525 |
+
|
| 1526 |
float * dst = (float *) mel->data;
|
| 1527 |
memset(dst, 0, ggml_nbytes(mel));
|
| 1528 |
|
|
|
|
| 1536 |
}
|
| 1537 |
}
|
| 1538 |
|
| 1539 |
+
struct ggml_tensor * cur = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1540 |
|
| 1541 |
+
if (!whisper_encode_external(wstate)) {
|
| 1542 |
// convolution + gelu
|
| 1543 |
{
|
|
|
|
|
|
|
| 1544 |
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
| 1545 |
cur = ggml_add(ctx0,
|
| 1546 |
ggml_repeat(ctx0,
|
|
|
|
| 1550 |
|
| 1551 |
cur = ggml_gelu(ctx0, cur);
|
| 1552 |
|
|
|
|
|
|
|
| 1553 |
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
| 1554 |
cur = ggml_add(ctx0,
|
| 1555 |
ggml_repeat(ctx0,
|
|
|
|
| 1560 |
cur = ggml_gelu(ctx0, cur);
|
| 1561 |
}
|
| 1562 |
|
| 1563 |
+
wstate.embd_conv = cur;
|
| 1564 |
+
} else {
|
| 1565 |
+
#ifdef WHISPER_USE_COREML
|
| 1566 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
| 1567 |
+
ggml_allocr_alloc(alloc, cur);
|
| 1568 |
|
| 1569 |
+
if (!ggml_allocr_is_measure(alloc)) {
|
| 1570 |
+
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
| 1571 |
+
}
|
| 1572 |
+
#endif
|
| 1573 |
+
#ifdef WHISPER_USE_OPENVINO
|
| 1574 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
| 1575 |
+
ggml_allocr_alloc(alloc, cur);
|
| 1576 |
|
| 1577 |
+
if (!ggml_allocr_is_measure(alloc)) {
|
| 1578 |
+
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
| 1579 |
+
}
|
| 1580 |
+
#endif
|
| 1581 |
|
| 1582 |
+
wstate.embd_enc = cur;
|
| 1583 |
+
}
|
|
|
|
|
|
|
| 1584 |
|
| 1585 |
+
ggml_build_forward_expand(gf, cur);
|
| 1586 |
|
| 1587 |
+
ggml_free(ctx0);
|
|
|
|
| 1588 |
|
| 1589 |
+
return gf;
|
| 1590 |
+
}
|
| 1591 |
|
| 1592 |
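The ggml_allocr_alloc / ggml_allocr_is_measure pairing above is the general input-write idiom for no_alloc graphs: during the measuring pass tensor->data is not a usable address, so host writes (the mel spectrogram here, token ids and scales later) must be skipped. The idiom in isolation — a sketch, assuming ggml.h / ggml-alloc.h from this commit:

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <cstring>

    static void set_input(ggml_allocr * alloc, struct ggml_tensor * t,
                          const float * src, size_t n) {
        ggml_allocr_alloc(alloc, t); // reserves (or merely measures) t's data

        if (!ggml_allocr_is_measure(alloc)) {
            memcpy(t->data, src, n*sizeof(float)); // safe: a real buffer now backs t
        }
    }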
+
static struct ggml_cgraph * whisper_build_graph_encoder(
|
| 1593 |
+
whisper_context & wctx,
|
| 1594 |
+
whisper_state & wstate) {
|
| 1595 |
+
const auto & model = wctx.model;
|
| 1596 |
+
const auto & hparams = model.hparams;
|
| 1597 |
|
| 1598 |
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 1599 |
+
const int n_state = hparams.n_audio_state;
|
| 1600 |
+
const int n_head = hparams.n_audio_head;
|
| 1601 |
+
const int n_layer = hparams.n_audio_layer;
|
| 1602 |
|
| 1603 |
+
struct ggml_init_params params = {
|
| 1604 |
+
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
| 1605 |
+
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
| 1606 |
+
/*.no_alloc =*/ true,
|
| 1607 |
+
};
|
| 1608 |
|
| 1609 |
+
struct ggml_context * ctx0 = ggml_init(params);
|
| 1610 |
|
| 1611 |
+
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
|
|
|
| 1612 |
|
| 1613 |
+
ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
|
|
|
|
|
| 1614 |
|
| 1615 |
+
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
| 1616 |
+
ggml_allocr_alloc(alloc, KQscale);
|
| 1617 |
|
| 1618 |
+
if (!ggml_allocr_is_measure(alloc)) {
|
| 1619 |
+
ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
|
| 1620 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1621 |
|
| 1622 |
+
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
|
|
|
|
|
| 1623 |
|
| 1624 |
+
// ===================================================================
|
| 1625 |
+
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
| 1626 |
+
//static int iter = -1;
|
| 1627 |
+
//const int n_iter = 1500/n_ctx;
|
| 1628 |
|
| 1629 |
+
//iter = (iter + 1) % n_iter;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1630 |
|
| 1631 |
+
//if (iter == 0) {
|
| 1632 |
+
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
|
| 1633 |
+
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
|
| 1634 |
+
//}
|
| 1635 |
|
| 1636 |
+
static int iter = 0;
|
|
|
|
|
|
|
|
|
|
| 1637 |
|
| 1638 |
+
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
|
| 1639 |
+
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
|
| 1640 |
|
| 1641 |
+
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
|
|
|
|
|
| 1642 |
|
| 1643 |
+
cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
|
| 1644 |
+
|
| 1645 |
+
// ===================================================================
|
| 1646 |
+
|
| 1647 |
+
// original:
|
| 1648 |
+
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
|
| 1649 |
+
|
| 1650 |
+
struct ggml_tensor * inpL = cur;
|
| 1651 |
+
|
| 1652 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 1653 |
+
const auto & layer = model.layers_encoder[il];
|
| 1654 |
+
|
| 1655 |
+
// norm
|
| 1656 |
+
{
|
| 1657 |
+
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
| 1658 |
+
|
| 1659 |
+
// cur = ln_0_w*cur + ln_0_b
|
| 1660 |
+
cur = ggml_add(ctx0,
|
| 1661 |
+
ggml_mul(ctx0, cur, layer.attn_ln_0_w),
|
| 1662 |
+
layer.attn_ln_0_b);
|
| 1663 |
+
}
|
| 1664 |
+
|
| 1665 |
+
// self-attention
|
| 1666 |
+
{
|
| 1667 |
+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1668 |
+
layer.attn_q_w,
|
| 1669 |
+
cur);
|
| 1670 |
+
|
| 1671 |
+
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
| 1672 |
+
|
| 1673 |
+
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1674 |
+
|
| 1675 |
+
// note: no bias for Key
|
| 1676 |
+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
| 1677 |
+
layer.attn_k_w,
|
| 1678 |
+
cur);
|
| 1679 |
+
|
| 1680 |
+
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1681 |
|
| 1682 |
+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
| 1683 |
+
layer.attn_v_w,
|
| 1684 |
+
cur);
|
| 1685 |
+
|
| 1686 |
+
Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
|
| 1687 |
|
| 1688 |
+
// ------
|
| 1689 |
|
| 1690 |
#ifdef WHISPER_USE_FLASH_ATTN
|
| 1691 |
+
struct ggml_tensor * Q =
|
| 1692 |
+
ggml_permute(ctx0,
|
| 1693 |
+
ggml_cpy(ctx0,
|
| 1694 |
+
Qcur,
|
| 1695 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
| 1696 |
+
0, 2, 1, 3);
|
| 1697 |
+
|
| 1698 |
+
struct ggml_tensor * K =
|
| 1699 |
+
ggml_permute(ctx0,
|
| 1700 |
+
ggml_cpy(ctx0,
|
| 1701 |
+
Kcur,
|
| 1702 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
| 1703 |
+
0, 2, 1, 3);
|
| 1704 |
+
|
| 1705 |
+
struct ggml_tensor * V =
|
| 1706 |
+
ggml_cpy(ctx0,
|
| 1707 |
+
ggml_permute(ctx0,
|
| 1708 |
+
ggml_reshape_3d(ctx0,
|
| 1709 |
+
Vcur,
|
| 1710 |
+
n_state/n_head, n_head, n_ctx),
|
| 1711 |
+
1, 2, 0, 3),
|
| 1712 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
| 1713 |
+
|
| 1714 |
+
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
|
| 1715 |
#else
|
| 1716 |
+
struct ggml_tensor * Q =
|
| 1717 |
+
ggml_permute(ctx0,
|
| 1718 |
+
ggml_cpy(ctx0,
|
| 1719 |
+
Qcur,
|
| 1720 |
+
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
| 1721 |
+
0, 2, 1, 3);
|
| 1722 |
+
|
| 1723 |
+
struct ggml_tensor * K =
|
| 1724 |
+
ggml_permute(ctx0,
|
| 1725 |
+
ggml_cpy(ctx0,
|
| 1726 |
+
Kcur,
|
| 1727 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
| 1728 |
+
0, 2, 1, 3);
|
| 1729 |
+
|
| 1730 |
+
// K * Q
|
| 1731 |
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 1732 |
+
|
| 1733 |
+
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
|
| 1734 |
+
|
| 1735 |
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
|
| 1736 |
+
|
| 1737 |
+
struct ggml_tensor * V =
|
| 1738 |
+
ggml_cpy(ctx0,
|
| 1739 |
+
ggml_permute(ctx0,
|
| 1740 |
+
ggml_reshape_3d(ctx0,
|
| 1741 |
+
Vcur,
|
| 1742 |
+
n_state/n_head, n_head, n_ctx),
|
| 1743 |
+
1, 2, 0, 3),
|
| 1744 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
| 1745 |
+
);
|
| 1746 |
+
|
| 1747 |
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1748 |
#endif
|
| 1749 |
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 1750 |
|
| 1751 |
+
cur = ggml_cpy(ctx0,
|
| 1752 |
+
KQV_merged,
|
| 1753 |
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
| 1754 |
+
}
|
| 1755 |
|
| 1756 |
+
// projection
|
| 1757 |
+
{
|
| 1758 |
+
cur = ggml_mul_mat(ctx0,
|
| 1759 |
+
layer.attn_ln_1_w,
|
| 1760 |
+
cur);
|
| 1761 |
|
| 1762 |
+
cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
|
| 1763 |
+
}
|
|
|
|
| 1764 |
|
| 1765 |
+
// add the input
|
| 1766 |
+
cur = ggml_add(ctx0, cur, inpL);
|
|
|
|
| 1767 |
|
| 1768 |
+
struct ggml_tensor * inpFF = cur;
|
| 1769 |
|
| 1770 |
+
// feed-forward network
|
| 1771 |
+
{
|
| 1772 |
+
// norm
|
| 1773 |
+
{
|
| 1774 |
+
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
| 1775 |
+
|
| 1776 |
+
// cur = mlp_ln_w*cur + mlp_ln_b
|
| 1777 |
cur = ggml_add(ctx0,
|
| 1778 |
+
ggml_mul(ctx0, cur, layer.mlp_ln_w),
|
| 1779 |
+
layer.mlp_ln_b);
|
| 1780 |
}
|
| 1781 |
|
| 1782 |
+
#ifdef WHISPER_USE_FLASH_FF
|
| 1783 |
+
cur = ggml_flash_ff(ctx0,
|
| 1784 |
+
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
| 1785 |
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
| 1786 |
+
#else
|
| 1787 |
+
// fully connected
|
| 1788 |
+
cur = ggml_mul_mat(ctx0,
|
| 1789 |
+
layer.mlp_0_w,
|
| 1790 |
+
cur);
|
| 1791 |
|
| 1792 |
+
cur = ggml_add(ctx0, cur, layer.mlp_0_b);
|
|
|
|
| 1793 |
|
| 1794 |
+
// GELU activation
|
| 1795 |
+
cur = ggml_gelu(ctx0, cur);
|
| 1796 |
|
| 1797 |
+
// projection
|
| 1798 |
+
cur = ggml_mul_mat(ctx0,
|
| 1799 |
+
layer.mlp_1_w,
|
| 1800 |
+
cur);
|
|
|
|
| 1801 |
|
| 1802 |
+
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
|
| 1803 |
+
#endif
|
| 1804 |
+
}
|
| 1805 |
+
|
| 1806 |
+
inpL = ggml_add(ctx0, cur, inpFF);
|
| 1807 |
+
}
|
| 1808 |
+
|
| 1809 |
+
cur = inpL;
|
| 1810 |
+
|
| 1811 |
+
// norm
|
| 1812 |
+
{
|
| 1813 |
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
| 1814 |
+
|
| 1815 |
+
// cur = ln_f_g*cur + ln_f_b
|
| 1816 |
+
cur = ggml_add(ctx0,
|
| 1817 |
+
ggml_mul(ctx0, cur, model.e_ln_w),
|
| 1818 |
+
model.e_ln_b);
|
| 1819 |
+
}
|
| 1820 |
+
|
| 1821 |
+
ggml_build_forward_expand(gf, cur);
|
| 1822 |
+
|
| 1823 |
+
wstate.embd_enc = cur;
|
| 1824 |
+
|
| 1825 |
+
//ggml_graph_print(gf);
|
| 1826 |
+
|
| 1827 |
+
////////////////////////////////////////////////////////////////////////////
|
| 1828 |
+
|
| 1829 |
+
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
| 1830 |
+
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
| 1831 |
+
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
| 1832 |
+
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
| 1833 |
+
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
| 1834 |
+
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
| 1835 |
+
|
| 1836 |
+
ggml_free(ctx0);
|
| 1837 |
+
|
| 1838 |
+
return gf;
|
| 1839 |
+
}
|
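A note on the attention scaling: the encoder bakes 1.0f/sqrt(float(n_state)/n_head) into KQscale and applies it once to K*Q, while the cross and decoder paths below apply pow(d_head, -0.25) to K and Q separately; the two schemes multiply out identically. A quick numeric check (d_head = 64 is typical, e.g. n_audio_state 512 / n_head 8 in the base model):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float d_head = 64.0f;

        const float once  = 1.0f/std::sqrt(d_head);   // encoder: single KQscale
        const float split = std::pow(d_head, -0.25f)
                          * std::pow(d_head, -0.25f); // cross/decoder: Kscale*Qscale

        printf("once = %f, split = %f\n", once, split); // both print 0.125
        return 0;
    }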
+
+// pre-compute cross-attention memory
+static struct ggml_cgraph * whisper_build_graph_cross(
+        whisper_context & wctx,
+        whisper_state   & wstate) {
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
+
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_state = hparams.n_audio_state;
+    const int n_head  = hparams.n_audio_head;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+        /*.no_alloc   =*/ true,
+    };

+    struct ggml_context * ctx0 = ggml_init(params);

+    ggml_cgraph * gf = ggml_new_graph(ctx0);

+    ggml_allocr * alloc = wstate.alloc_cross.alloc;

+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);

+    struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(alloc, Kscale);

+    if (!ggml_allocr_is_measure(alloc)) {
+        ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+    }

+    for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+        auto & layer = model.layers_decoder[il];

+        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_k_w,
+                cur);

+        Kcross = ggml_scale(ctx0, Kcross, Kscale);

+        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_v_w,
+                cur);

+        Vcross = ggml_add(ctx0,
+                Vcross,
+                layer.cross_attn_v_b);

+        Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));

+        struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
+                n_state*n_ctx,
+                (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));

+        struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);

+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
+    }

+    //ggml_graph_print(gf);

+    ggml_free(ctx0);

+    return gf;
+}

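whisper_build_graph_cross runs once per encode and writes every layer's cross-attention K/V into one flat cache, so the decoder later only reads views into it. A standalone sketch of the byte-offset arithmetic used above — base-model values (n_audio_state = 512, n_audio_ctx = 1500) and F16 elements, chosen purely for illustration:

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t n_state   = 512;
        const int64_t n_ctx     = 1500; // n_audio_ctx
        const int64_t elem_size = 2;    // F16

        // each decoder layer il owns one contiguous n_state*n_ctx slab
        for (int il = 0; il < 3; ++il) {
            const int64_t k_offset = (elem_size*n_state)*(il*n_ctx);
            printf("layer %d: K slab at byte offset %lld, %lld bytes long\n",
                   il, (long long) k_offset, (long long) (elem_size*n_state*n_ctx));
        }
        return 0;
    }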
+// evaluate the encoder with the given state
+//
+// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
+// part of the transformer model and returns the encoded features
+//
+//   - wctx:       the model
+//   - wstate:     the state of the encoder
+//   - n_threads:  number of threads to use
+//   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
+//
+static bool whisper_encode_internal(
+        whisper_context & wctx,
+        whisper_state   & wstate,
+        const int         mel_offset,
+        const int         n_threads) {
+    const int64_t t_start_us = ggml_time_us();

+    // conv
+    {
+        auto & alloc = wstate.alloc_conv.alloc;

+        ggml_allocr_reset(alloc);

+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);

+        ggml_allocr_alloc_graph(alloc, gf);

+        if (!whisper_encode_external(wstate)) {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
         }
     }

+    // encoder
+    if (!whisper_encode_external(wstate)) {
+        auto & alloc = wstate.alloc_encode.alloc;

+        ggml_allocr_reset(alloc);

+        ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);

+        ggml_allocr_alloc_graph(alloc, gf);
+
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
         }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
 #endif
+    }

+    // cross
     {
+        auto & alloc = wstate.alloc_cross.alloc;

+        ggml_allocr_reset(alloc);

+        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);

+        ggml_allocr_alloc_graph(alloc, gf);

+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
         }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
     }

+    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;

     return true;
 }

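The Metal-or-CPU dispatch just shown repeats verbatim for the encoder, cross and decoder graphs. A hypothetical helper that factors out the pattern is sketched below; it is not part of this commit, and the ggml_graph_compute_helper declaration mirrors (as an assumption) the static helper defined earlier in whisper.cpp:

    #include "ggml.h"
    #ifdef GGML_USE_METAL
    #include "ggml-metal.h"
    #endif

    #include <vector>

    // assumed to match the helper defined near the top of whisper.cpp
    void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads);

    static void graph_compute(
    #ifdef GGML_USE_METAL
            ggml_metal_context   * ctx_metal,
    #endif
            std::vector<uint8_t> & work_buffer,
            ggml_cgraph          * gf,
            int                    n_threads) {
    #ifdef GGML_USE_METAL
        if (ctx_metal) {
            // run the whole graph on the GPU; n_cb sets the command-buffer count
            ggml_metal_set_n_cb     (ctx_metal, n_threads);
            ggml_metal_graph_compute(ctx_metal, gf);
            return;
        }
    #endif
        // CPU fallback reusing the shared work buffer
        ggml_graph_compute_helper(work_buffer, gf, n_threads);
    }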
+static struct ggml_cgraph * whisper_build_graph_decoder(
+        whisper_context & wctx,
+        whisper_state   & wstate,
+        whisper_decoder & decoder,
+        const whisper_token * tokens,
+        int   n_tokens,
+        int   n_past) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;

...
     WHISPER_ASSERT(!!kv_self.ctx);

     const int n_ctx = hparams.n_text_ctx;
     const int n_state = hparams.n_text_state;
     const int n_head = hparams.n_text_head;
...
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);

     struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+        /*.no_alloc   =*/ true,
     };

     struct ggml_context * ctx0 = ggml_init(params);

+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;

     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(alloc, embd);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    }

     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(alloc, position);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
     }

+    struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(alloc, KQscale);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
+    }

     // token encoding + position encoding
     struct ggml_tensor * cur =
...
         // norm
         {
             cur = ggml_norm(ctx0, inpL, hparams.eps);

             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
+                        cur,
+                        layer.attn_ln_0_w),
+                    layer.attn_ln_0_b);
         }

         // self-attention
...
                     cur);

             Qcur = ggml_add(ctx0,
+                    Qcur,
+                    layer.attn_q_b);

+            Qcur = ggml_scale(ctx0, Qcur, KQscale);

             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);

+            Kcur = ggml_scale(ctx0, Kcur, KQscale);

             // store key and value to memory
             {
...
                         cur);

                 Vcur = ggml_add(ctx0,
+                        Vcur,
+                        layer.attn_v_b);

                 Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
...
                         ( n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));

+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }

             // ------

             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);

             struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_state/n_head, n_past + N, n_head,
+                        ggml_element_size(kv_self.k)*n_state,
+                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state*n_ctx*il);

             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

+            //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);

+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);

+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
...
             // projection
             {
                 cur = ggml_mul_mat(ctx0,
                         layer.attn_ln_1_w,
                         cur);

                 cur = ggml_add(ctx0,
+                        cur,
+                        layer.attn_ln_1_b);
             }

             // add the input
             struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);

             // norm
             {
                 cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here

                 // cur = ln_0_w*cur + ln_0_b
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
+                            cur,
+                            layer.cross_attn_ln_0_w),
+                        layer.cross_attn_ln_0_b);
             }

             // cross-attention
...
                     cur);

             Qcur = ggml_add(ctx0,
+                    Qcur,
+                    layer.cross_attn_q_b);

+            Qcur = ggml_scale(ctx0, Qcur, KQscale);

             // Kcross is already scaled
             struct ggml_tensor * Kcross =
+                ggml_view_3d(ctx0, wstate.kv_cross.k,
+                        n_state/n_head, M, n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state,
+                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);

             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
...

             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);

             // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);

             //struct ggml_tensor * KQ_scaled =
+            //    ggml_scale(ctx0,
             //            KQ,
             //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
             //            );

             // no masking for cross-attention
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
...
             // projection
             {
                 cur = ggml_mul_mat(ctx0,
                         layer.cross_attn_ln_1_w,
                         cur);

                 cur = ggml_add(ctx0,
+                        cur,
+                        layer.cross_attn_ln_1_b);
             }

             // add the input
             cur = ggml_add(ctx0, cur, inpCA);
...
         {
             // norm
             {
                 cur = ggml_norm(ctx0, inpFF, hparams.eps);

                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
+                            cur,
+                            layer.mlp_ln_w),
+                        layer.mlp_ln_b);
             }

             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);

             cur = ggml_add(ctx0,
+                    cur,
+                    layer.mlp_0_b);

             // GELU activation
             cur = ggml_gelu(ctx0, cur);

             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);

             cur = ggml_add(ctx0,
+                    cur,
+                    layer.mlp_1_b);
         }

         inpL = ggml_add(ctx0, cur, inpFF);
     }
...

     // norm
     {
         cur = ggml_norm(ctx0, cur, hparams.eps);

         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
+                    cur,
+                    model.d_ln_w),
+                model.d_ln_b);
     }

     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
     // might be useful in the future
...

     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

+    ggml_build_forward_expand(gf, logits);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the decoder
+//
+// given text prompt + audio features -> computes the logits for the next token
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - tokens:    text prompt
+//   - n_tokens:  number of tokens in the prompt
+//   - n_past:    number of past tokens to prefix the prompt with
+//
+static bool whisper_decode_internal(
+        whisper_context & wctx,
+        whisper_state   & wstate,
+        whisper_decoder & decoder,
+        const whisper_token * tokens,
+        const int   n_tokens,
+        const int   n_past,
+        const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    auto & logits_out = wstate.logits;
+
+    struct ggml_tensor * logits;

+    // decoder
     {
+        auto & alloc = wstate.alloc_decode.alloc;
+
+        ggml_allocr_reset(alloc);
+
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+
+        ggml_allocr_alloc_graph(alloc, gf);
+
+        logits = gf->nodes[gf->n_nodes - 1];
+
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
     }

     // extract logits for all N tokens
+    //logits_out.resize(n_tokens*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);

     // extract logits only for the last token
     logits_out.resize(n_vocab);
     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);

+    if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
...
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }

+    if (n_tokens == 1) {
+        wstate.t_decode_us += ggml_time_us() - t_start_us;
+        wstate.n_decode++;
+    } else {
+        wstate.t_prompt_us += ggml_time_us() - t_start_us;
+        wstate.n_prompt++;
+    }

     return true;
 }
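Splitting the decoder timing on n_tokens makes the per-run averages meaningful: a batched prompt evaluation costs far more per call than single-token generation, and lumping them together skews both numbers. An illustrative computation of the averages whisper_print_timings() reports (the figures are made up):

    #include <algorithm>
    #include <cstdio>
    #include <cstdint>

    int main() {
        int64_t t_decode_us = 1200000; // accumulated over n_tokens == 1 calls
        int64_t t_prompt_us =  300000; // accumulated over n_tokens >  1 calls
        int32_t n_decode    = 240;
        int32_t n_prompt    = 2;

        const int32_t nd = std::max(1, n_decode); // guard against division by zero
        const int32_t np = std::max(1, n_prompt);

        printf("decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n",
               1e-3f*t_decode_us, nd, 1e-3f*t_decode_us/nd);
        printf("prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n",
               1e-3f*t_prompt_us, np, 1e-3f*t_prompt_us/np);
        return 0;
    }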
...

+
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
...
     fill_sin_cos_table();
     whisper_state * state = new whisper_state;

+    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
...
         log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }

+    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
         log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
...
     if (!state->ctx_coreml) {
         log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
+        delete state;
         return nullptr;
 #endif
     } else {
...
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);

+    state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);

+    // conv allocator
+    {
+        whisper_allocr_graph_init(state->alloc_conv,
+                [&]() {
+                    return whisper_build_graph_conv(*ctx, *state, 0);
+                });
+
+        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+    }
+
+    // encoder allocator
+    if (!whisper_encode_external(*state)) {
+        whisper_allocr_graph_init(state->alloc_encode,
+                [&]() {
+                    return whisper_build_graph_encoder(*ctx, *state);
+                });
+
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+    }
+
+    // cross allocator
+    {
+        whisper_allocr_graph_init(state->alloc_cross,
+                [&]() {
+                    return whisper_build_graph_cross(*ctx, *state);
+                });
+
+        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+    }
+
+    // decoder allocator
+    {
+        whisper_allocr_graph_init(state->alloc_decode,
+                [&]() {
+                    const auto & hparams = ctx->model.hparams;
+
+                    // TODO: make sure this is the worst-case scenario
+                    const int n_tokens = hparams.n_text_ctx;
+                    const int n_past   = 0;
+
+                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                });
+
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+    }
+
+#ifdef GGML_USE_METAL
+    state->ctx_metal = ggml_metal_init(1);
+    if (!state->ctx_metal) {
+        log("%s: ggml_metal_init() failed\n", __func__);
+        delete state;
+        return nullptr;
+    }
+
+    log("%s: Metal context initialized\n", __func__);
+
+    // this allocates all Metal resources and memory buffers
+
+    void * data_ptr  = NULL;
+    size_t data_size = 0;
+
+    // TODO: add mmap support
+    //if (params.use_mmap) {
+    //    data_ptr  = ctx->model.mapping->addr;
+    //    data_size = ctx->model.mapping->size;
+    //} else {
+    //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    //    data_size = ggml_get_mem_size  (ctx->model.ctx);
+    //}
+
+    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    data_size = ggml_get_mem_size  (ctx->model.ctx);
+
+    const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+
+    log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+
+#define WHISPER_METAL_CHECK_BUF(result) \
+    if (!(result)) { \
+        log("%s: failed to add metal buffer\n", __func__); \
+        delete state; \
+        return nullptr; \
+    }
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif

     state->rng = std::mt19937(0);

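Every host allocation a Metal graph touches has to be registered this way: the model weights ("data"), each allocator's meta and data buffers, and the KV caches. On Apple silicon, ggml_metal_add_buffer maps the existing host range into a Metal buffer rather than copying it, which is why the same std::vector storage can back both CPU and GPU execution. A hedged sketch of the same call for one cache buffer, using only the signature exercised above:

    #ifdef GGML_USE_METAL
    #include "ggml-metal.h"

    #include <vector>

    static bool register_kv_buffer(ggml_metal_context * ctx_metal,
                                   const char * name, std::vector<uint8_t> & buf) {
        // max_size of 0: no single tensor inside exceeds the buffer itself
        return ggml_metal_add_buffer(ctx_metal, name, buf.data(), buf.size(), 0);
    }
    #endif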
...
 }

 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
     log("%s: loading model from '%s'\n", __func__, path_model);

     auto fin = std::ifstream(path_model, std::ios::binary);
...
     }
 #endif

+#ifdef GGML_USE_METAL
+    if (state->ctx_metal) {
+        ggml_metal_free(state->ctx_metal);
+        state->ctx_metal = nullptr;
+    }
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
     if (state->ctx_openvino != nullptr) {
         whisper_openvino_free(state->ctx_openvino);
...
     }
 #endif

+    whisper_allocr_free(state->alloc_conv);
+    whisper_allocr_free(state->alloc_decode);
+    whisper_allocr_free(state->alloc_cross);
+    whisper_allocr_free(state->alloc_encode);
+
         delete state;
     }
 }
...
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_prompt = std::max(1, ctx->state->n_prompt);

         log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
         log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
         log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
...
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
+        ctx->state->t_prompt_us = 0;
+        ctx->state->n_sample = 0;
+        ctx->state->n_encode = 0;
+        ctx->state->n_decode = 0;
+        ctx->state->n_prompt = 0;
     }
 }

...
             decoder.probs.resize (ctx->vocab.n_vocab);
             decoder.logits.resize (ctx->vocab.n_vocab);
             decoder.logprobs.resize(ctx->vocab.n_vocab);
+
+            // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
+#ifdef GGML_USE_METAL
+#define WHISPER_METAL_CHECK_BUF(result) \
+            if (!(result)) { \
+                log("%s: failed to add metal buffer\n", __func__); \
+                return 0; \
+            }
+
+            const std::string kv_name = "kv_self_" + std::to_string(j);
+            auto & kv_self = decoder.kv_self;
+
+            WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif
         }
     }

...

         decoder.kv_self.n += prompt.size();

+        memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
+        memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
         memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
     }

...
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_prompt_us += states[i]->t_prompt_us;
+
+        ctx->state->n_sample += states[i]->n_sample;
+        ctx->state->n_encode += states[i]->n_encode;
+        ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_prompt += states[i]->n_prompt;

         whisper_free_state(states[i]);
     }
...
     // b: N*N*sizeof(float)
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
+    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
+    std::vector<uint8_t> work;

     // put a bunch of random data in the buffer
     for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
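For a rough sense of the scale of the bench allocation just above — N_max is defined elsewhere in whisper.cpp, so the value below is purely hypothetical:

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t N_max    = 1024;  // hypothetical largest matrix side
        const size_t overhead = 3*256; // stand-in for 3*ggml_tensor_overhead()

        // three N*N float matrices (a, b, c) plus per-tensor metadata
        const size_t bytes = 3llu*N_max*N_max*sizeof(float) + overhead;

        printf("bench buffer: %.2f MB\n", bytes/1024.0/1024.0); // ~12 MB for N_max = 1024
        return 0;
    }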