ggerganov committed
Commit 714ee6b (unverified)
Parent: abbf5f2

whisper : Metal and ggml-alloc support (#1270)


* metal : init

* whisper : factor out graph builds

* whisper : allocate encoder and decoder using ggml-alloc

* whisper : ggml-alloc is now supported

* whisper : CoreML support ggml-alloc

* build : fix ggml-alloc

* ios : update submodule

* extra : update sync-ggml.sh script to also sync ggml-alloc

* ci : see if this is causing the crash

* whisper : refactor ggml-alloc init

* whisper.android : try to fix build

* whisper : initial Metal version

* ci : try to debug vmem issue

* metal : decoder works on GPU!

* metal : add multi-decoder support

* ggml : fix ggml_nbytes (probably temp solution)

* metal : run "cross" step on the GPU

* whisper : remove ggml_repeat in the encoder

* whisper : offload the Encoder to Metal

* ggml : use simpler ggml_bytes() implementation

* ggml-alloc : try to make CI happy by reducing vram to 128GB

* whisper : add whisper_allocr to wrap ggml_allocr

* whisper : factor out alloc init in a function

* cmake : update to support Metal build

* whisper : add <functional> header

* objc : fix build (no Metal yet)

* ios : add Metal support

* swiftui : fix build

* metal : speed-up KQ multiplication

* metal : sync latest llama.cpp kernels

* readme : add Metal info

* ios : update submodule

* coreml : add code to toggle Core ML config (CPU, ANE, GPU)

* bench : fix timings by running a pre-heat

* bench : start benching the decoder

* whisper : add ggml_mul_mat_pad

* bench : fix uninitialized vars

* whisper : add comment for disabling mul-mat padding

* whisper : add description of ggml_mul_mat_pad

* whisper : clean-up ggml_mul_mat_pad

* metal : remove the "concurrent" flag

* bench : variable n_past

* ios : update SPM package

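As a quick orientation before the file-by-file diff, here is a minimal sketch of how the new Metal path is typically exercised on Apple Silicon after this change; the model path, sample file and binary location are illustrative rather than part of the commit:

    # Metal defaults to ON on Apple platforms (see the CMakeLists.txt changes below)
    cmake -B build
    cmake --build build -j
    # both the Encoder and the Decoder are offloaded to the GPU via Metal
    ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav
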
CMakeLists.txt CHANGED
@@ -1,4 +1,4 @@
-cmake_minimum_required (VERSION 3.0)
+cmake_minimum_required (VERSION 3.5)
 
 project(whisper.cpp VERSION 1.4.2)
 
@@ -35,6 +35,12 @@ endif()
 
 # options
 
+if (APPLE)
+    set(WHISPER_METAL_DEFAULT ON)
+else()
+    set(WHISPER_METAL_DEFAULT OFF)
+endif()
+
 option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
 
 option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
@@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
 
 if (APPLE)
     option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
+    option(WHISPER_METAL "whisper: use Metal" ${WHISPER_METAL_DEFAULT})
+    option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF)
     option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
     option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
 else()
@@ -113,6 +121,34 @@ if (APPLE)
     endif()
 endif()
 
+if (WHISPER_METAL)
+    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+    find_library(METAL_FRAMEWORK    Metal      REQUIRED)
+    find_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)
+
+    if (METAL_FRAMEWORK)
+        message(STATUS "Metal framework found")
+
+        set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
+            ${FOUNDATION_LIBRARY}
+            ${METAL_FRAMEWORK}
+            ${METALKIT_FRAMEWORK}
+            )
+        set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL)
+
+        if (WHISPER_METAL_NDEBUG)
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
+        endif()
+    else()
+        message(WARNING "Metal framework not found")
+    endif()
+
+    set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
+
+    # copy ggml-metal.metal to bin directory
+    configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
+endif()
+
 if (WHISPER_COREML)
     find_library(FOUNDATION_FRAMEWORK Foundation)
     find_library(COREML_FRAMEWORK CoreML)
@@ -177,7 +213,7 @@ if (WHISPER_CUBLAS)
 
     enable_language(CUDA)
 
-    set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+    set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
 
     add_compile_definitions(GGML_USE_CUBLAS)
 
@@ -228,7 +264,7 @@ if (WHISPER_CLBLAST)
     if (CLBlast_FOUND)
         message(STATUS "CLBlast found")
 
-        set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
+        set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)
 
         add_compile_definitions(GGML_USE_CLBLAST)
 
@@ -426,8 +462,11 @@ set(TARGET whisper)
 add_library(${TARGET}
     ggml.h
     ggml.c
-    ${GGML_CUDA_SOURCES}
-    ${GGML_OPENCL_SOURCES}
+    ggml-alloc.h
+    ggml-alloc.c
+    ${GGML_SOURCES_METAL}
+    ${GGML_SOURCES_CUDA}
+    ${GGML_SOURCES_OPENCL}
     whisper.h
     whisper.cpp
     )
@@ -468,9 +507,15 @@ if (BUILD_SHARED_LIBS)
         WHISPER_BUILD
         GGML_BUILD
         )
+
+    if (WHISPER_METAL)
+        # TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory
+        #       but for some reason it does not work here like it does in llama.cpp
+        set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    endif()
 endif()
 
-if (GGML_CUDA_SOURCES)
+if (GGML_SOURCES_CUDA)
     message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
     set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
     set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
@@ -486,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC
 
 set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
 
+include(GNUInstallDirs)
+
 install(TARGETS ${TARGET}
-    LIBRARY DESTINATION lib
-    ARCHIVE DESTINATION lib/static
-    RUNTIME DESTINATION bin
+    LIBRARY DESTINATION lib
+    ARCHIVE DESTINATION lib/static
+    RUNTIME DESTINATION bin
+    RESOURCE DESTINATION bin
     PUBLIC_HEADER DESTINATION include
     )
 
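A hedged example of how the new CMake options above might be used; the option names come from the diff, the invocations themselves are illustrative:

    # Metal is the default on Apple platforms (WHISPER_METAL_DEFAULT=ON); opt out explicitly
    cmake -B build -DWHISPER_METAL=OFF
    # or keep Metal but silence its debug logging
    cmake -B build -DWHISPER_METAL=ON -DWHISPER_METAL_NDEBUG=ON
    cmake --build build -j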
Makefile CHANGED
@@ -18,7 +18,7 @@ ifndef NVCC_VERSION
 	endif
 endif
 
-CCV := $(shell $(CC) --version | head -n 1)
+CCV := $(shell $(CC) --version | head -n 1)
 CXXV := $(shell $(CXX) --version | head -n 1)
 
 # Mac OS + Arm can report x86_64
@@ -182,6 +182,15 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK
 	endif
 endif
 
+ifndef WHISPER_NO_METAL
+ifeq ($(UNAME_S),Darwin)
+	WHISPER_METAL := 1
+
+	CXXFLAGS += -DGGML_USE_METAL
+	LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
+endif
+endif
+
 ifdef WHISPER_OPENBLAS
 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
 	LDFLAGS += -lopenblas
@@ -288,6 +297,11 @@ $(info )
 ggml.o: ggml.c ggml.h ggml-cuda.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-alloc.o
+
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
@@ -303,6 +317,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp
 WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
 endif
 
+ifdef WHISPER_METAL
+ggml-metal.o: ggml-metal.m ggml-metal.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-metal.o
+endif
+
 libwhisper.a: ggml.o $(WHISPER_OBJ)
 	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
 
 
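For the Makefile path, Metal is picked up automatically on Darwin and the new WHISPER_NO_METAL switch opts out; a hedged sketch of the two builds:

    # default macOS build: ggml-alloc.o and ggml-metal.o end up in libwhisper.a
    make clean && make -j
    # CPU-only build without Metal
    make clean && WHISPER_NO_METAL=1 make -j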
README.md CHANGED
@@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S
 High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
 
 - Plain C/C++ implementation without dependencies
-- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
+- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
 - AVX intrinsics support for x86 architectures
 - VSX intrinsics support for POWER architectures
 - Mixed F16 / F32 precision
 - [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
 - Low memory usage (Flash Attention)
 - Zero memory allocations at runtime
-- Runs on the CPU
+- Support for CPU-only inference
 - [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
 - [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
 - [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
@@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command]
 
 https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
 
+On Apply Silicon, the inference runs fully on the GPU via Metal:
+
+https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
+
 Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
 
 ## Implementation details
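Tying the README change to a concrete run, a minimal sketch of transcribing a sample on Apple Silicon with the default (Metal-enabled) build; model and audio paths are illustrative:

    ./models/download-ggml-model.sh base.en
    ./main -m models/ggml-base.en.bin -f samples/jfk.wav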
bindings/ios CHANGED
@@ -1 +1 @@
-Subproject commit de46d9e7817fe851c109d66080239d415812d32a
+Subproject commit 22a9eef021afc67f2154bc9811ed620b26299d1b
coreml/whisper-encoder.mm CHANGED
@@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
 
     NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
 
-    const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
+    // select which device to run the Core ML model on
+    MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
+    config.computeUnits = MLComputeUnitsCPUAndGPU;
+    //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
+    //config.computeUnits = MLComputeUnitsAll;
+
+    const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);
 
     if (data == NULL) {
         return NULL;
examples/bench/bench.cpp CHANGED
@@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
     fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
-    fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
+    fprintf(stderr, " %-7s 0 - whisper\n", "");
     fprintf(stderr, " %-7s 1 - memcpy\n", "");
     fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
     fprintf(stderr, "\n");
 }
 
-int whisper_bench_encoder(const whisper_params & params) {
+int whisper_bench_full(const whisper_params & params) {
     // whisper init
 
     struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -69,12 +69,49 @@ int whisper_bench_encoder(const whisper_params & params) {
         fprintf(stderr, "error: failed to set mel: %d\n", ret);
         return 3;
     }
+    // heat encoder
+    if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }
+
+    whisper_token tokens[512];
+    memset(tokens, 0, sizeof(tokens));
+
+    // prompt heat
+    if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }
+
+    // text-generation heat
+    if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
+        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        return 4;
+    }
 
+    whisper_reset_timings(ctx);
+
+    // actual run
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
         fprintf(stderr, "error: failed to encode model: %d\n", ret);
         return 4;
     }
 
+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+            return 4;
+        }
+    }
+
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+            return 4;
+        }
+    }
+
     whisper_print_timings(ctx);
     whisper_free(ctx);
 
@@ -103,7 +140,7 @@ int main(int argc, char ** argv) {
     int ret = -1;
 
     switch (params.what) {
-        case 0: ret = whisper_bench_encoder(params); break;
+        case 0: ret = whisper_bench_full(params); break;
         case 1: ret = whisper_bench_memcpy(params.n_threads); break;
         case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
         default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
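With the changes above, `-w 0` benchmarks the whole pipeline (encoder pre-heat, prompt decoding with 256 tokens, and 256 single-token decodes) instead of the encoder alone. A hedged example invocation; thread count and model path are illustrative:

    ./bench -m ./models/ggml-base.en.bin -t 4 -w 0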
examples/talk-llama/CMakeLists.txt CHANGED
@@ -7,7 +7,7 @@ if (WHISPER_SDL2)
 
     # TODO: this is temporary
     # need to export ggml symbols for MSVC, but too lazy ..
-    add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
+    add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)
 
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
     target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt CHANGED
@@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
 set(
     SOURCE_FILES
     ${WHISPER_LIB_DIR}/ggml.c
+    ${WHISPER_LIB_DIR}/ggml-alloc.c
     ${WHISPER_LIB_DIR}/whisper.cpp
     ${CMAKE_SOURCE_DIR}/jni.c
     )
@@ -20,7 +21,7 @@ function(build_library target_name)
         SHARED
         ${SOURCE_FILES}
     )
-
+
     target_link_libraries(${target_name} ${LOG_LIB} android)
 
     if (${target_name} STREQUAL "whisper_v8fp16_va")
examples/whisper.objc/README.md CHANGED
@@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription:
 
 <img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
 
+## Core ML
+
 If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
 
 <img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
@@ -35,3 +37,13 @@ If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DW
 Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
 
 In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
+
+## Metal
+
+You can also enable Metal to make the inference run on the GPU of your device. This might or might not be more efficient
+compared to Core ML depending on the model and device that you use.
+
+To enable Metal, just add `-DGGML_USE_METAL` instead off the `-DWHISPER_USE_COREML` flag and you are ready.
+This will make both the Encoder and the Decoder run on the GPU.
+
+If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all!
examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj CHANGED
@@ -7,6 +7,9 @@
 	objects = {
 
 /* Begin PBXBuildFile section */
+		1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
+		1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
+		184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
 		18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
 		18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
 		18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
@@ -14,7 +17,7 @@
 		18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
 		18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
 		18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
-		18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
+		18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
 		18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
 		18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
 		7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
@@ -23,7 +26,24 @@
 		7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
 /* End PBXBuildFile section */
 
+/* Begin PBXCopyFilesBuildPhase section */
+		184447202AB21B25007D6BFE /* CopyFiles */ = {
+			isa = PBXCopyFilesBuildPhase;
+			buildActionMask = 2147483647;
+			dstPath = "";
+			dstSubfolderSpec = 7;
+			files = (
+				184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */,
+			);
+			runOnlyForDeploymentPostprocessing = 0;
+		};
+/* End PBXCopyFilesBuildPhase section */
+
 /* Begin PBXFileReference section */
+		184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = "<group>"; };
+		184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = "<group>"; };
+		1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = "<group>"; };
+		1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = "<group>"; };
 		18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
 		18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
 		18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
@@ -80,6 +100,10 @@
 		18627C7829052BDF00BD2A04 /* whisper.objc */ = {
 			isa = PBXGroup;
 			children = (
+				1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
+				1844471B2AB21655007D6BFE /* ggml-metal.m */,
+				184447182AB211A2007D6BFE /* ggml-alloc.c */,
+				184447192AB211A2007D6BFE /* ggml-alloc.h */,
 				7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
 				7FE342442A0C3FA20015A058 /* coreml */,
 				18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
@@ -126,6 +150,7 @@
 				18627C7229052BDF00BD2A04 /* Sources */,
 				18627C7329052BDF00BD2A04 /* Frameworks */,
 				18627C7429052BDF00BD2A04 /* Resources */,
+				184447202AB21B25007D6BFE /* CopyFiles */,
 			);
 			buildRules = (
 			);
@@ -194,8 +219,10 @@
 				18627C9629052C5800BD2A04 /* ggml.c in Sources */,
 				18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
 				7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
+				1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
 				18627C8C29052BE000BD2A04 /* main.m in Sources */,
 				18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
+				1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
 				7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj CHANGED
@@ -20,6 +20,7 @@
 		0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
 		0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
 		0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
+		18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
 /* End PBXBuildFile section */
 
 /* Begin PBXFileReference section */
@@ -41,6 +42,8 @@
 		0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
 		0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
 		0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
+		18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
+		18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
 /* End PBXFileReference section */
 
 /* Begin PBXFrameworksBuildPhase section */
@@ -124,6 +127,8 @@
 		0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
 			isa = PBXGroup;
 			children = (
+				18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
+				18AED4802AB21F2B009D854F /* ggml-alloc.h */,
 				0AAC5DC929539EB0003032C3 /* ggml.c */,
 				0AAC5DCA29539EB0003032C3 /* ggml.h */,
 				0AAC5DC729539EB0003032C3 /* whisper.cpp */,
@@ -242,6 +247,7 @@
 				0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
 				0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
 				0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
+				18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -369,7 +375,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
-				DEVELOPMENT_TEAM = 3TZ9BM962G;
+				DEVELOPMENT_TEAM = P8JZH34X63;
 				ENABLE_HARDENED_RUNTIME = YES;
 				ENABLE_PREVIEWS = YES;
 				GENERATE_INFOPLIST_FILE = YES;
@@ -410,7 +416,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
-				DEVELOPMENT_TEAM = 3TZ9BM962G;
+				DEVELOPMENT_TEAM = P8JZH34X63;
 				ENABLE_HARDENED_RUNTIME = YES;
 				ENABLE_PREVIEWS = YES;
 				GENERATE_INFOPLIST_FILE = YES;
extra/bench-all.sh CHANGED
@@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n"
-printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
-    # run once to heat-up the cache
-    ./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null
-
     # actual run
     # store stderr output in a variable in order to parse it later
     output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
     ret=$?
 
     # parse the output:
-    load_time=$(echo "$output" | grep "load time" | awk '{print $5}')
-    encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}')
+    encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
+    decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+    prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
     system_info=$(echo "$output" | grep "system_info")
    n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
 
     # floor to milliseconds
-    load_time=${load_time%.*}
-    encode_time=${encode_time%.*}
+    #encode_time=${encode_time%.*}
+    #decode_time=${decode_time%.*}
+    #prompt_time=${prompt_time%.*}
 
     config=""
 
@@ -87,6 +86,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"
+        printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
     fi
 done
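The updated table now reports encode, decode and prompt-processing times instead of load/encode, and the separate warm-up run is gone because bench itself now pre-heats. A hedged sketch of regenerating the table, assuming the referenced models are already downloaded to ./models:

    make clean && make -j bench
    ./extra/bench-all.sh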
extra/sync-ggml.sh CHANGED
@@ -1,18 +1,20 @@
 #!/bin/bash
 
-cp -rpv ../ggml/src/ggml.c ./ggml.c
-cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
-cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
-cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
-cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
-cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
-cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
-cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
-cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
-cp -rpv ../ggml/examples/common.h ./examples/common.h
-cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
-cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
-cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
+cp -rpv ../ggml/src/ggml.c ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
+cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
+cp -rpv ../ggml/examples/common.h ./examples/common.h
+cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
+cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
+cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
 
 cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
 cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
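The sync script assumes a sibling checkout of the upstream ggml repository; a hedged usage sketch:

    # run from the whisper.cpp root, with ../ggml checked out next to it
    ./extra/sync-ggml.sh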
ggml-alloc.c CHANGED
@@ -6,6 +6,26 @@
6
  #include <stdlib.h>
7
  #include <string.h>
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  #define UNUSED(x) (void)(x)
10
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
11
  #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
@@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
99
  }
100
  #endif
101
 
102
-
103
- static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
104
  return ggml_nbytes(tensor);
105
 
106
  UNUSED(alloc);
107
  }
108
 
 
 
 
 
 
 
 
 
 
 
109
  void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
110
- size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
 
 
 
 
111
  size = aligned_offset(NULL, size, alloc->alignment);
112
 
113
  AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -131,14 +164,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
131
  if (best_fit_block == -1) {
132
  // the last block is our last resort
133
  struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
 
134
  if (block->size >= size) {
135
  best_fit_block = alloc->n_free_blocks - 1;
136
- max_avail = MAX(max_avail, block->size);
137
  } else {
138
  fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
139
  __func__, size, max_avail);
140
  GGML_ASSERT(!"not enough space in the buffer");
141
- return;
142
  }
143
  }
144
  struct free_block * block = &alloc->free_blocks[best_fit_block];
@@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
173
  }
174
 
175
  // this is a very naive implementation, but for our case the number of free blocks should be very small
176
- static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
177
  void * ptr = tensor->data;
178
 
179
- if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
180
  // the tensor was not allocated in this buffer
181
  // this can happen because the graph allocator will try to free weights and other tensors from different buffers
182
  // the easiest way to deal with this is just to ignore it
183
  return;
184
  }
185
 
186
- size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
187
  size = aligned_offset(NULL, size, alloc->alignment);
188
  AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
189
 
@@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
277
  return alloc;
278
  }
279
 
280
- // address and size of the buffer when measuring
281
- // it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
282
- static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
283
- static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
286
  struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
287
 
 
 
 
 
 
288
  *alloc = (struct ggml_allocr){
289
- /*.data = */ MEASURE_BASE_ADDR,
290
- /*.size = */ MEASURE_MAX_SIZE,
291
  /*.alignment = */ alignment,
292
  /*.n_free_blocks = */ 0,
293
  /*.free_blocks = */ {{0}},
@@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
307
  }
308
 
309
  void ggml_allocr_free(struct ggml_allocr * alloc) {
 
 
 
310
  free(alloc);
311
  }
312
 
@@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
316
 
317
  //////////// compute graph allocator
318
 
319
- static bool ggml_is_view(struct ggml_tensor * t) {
320
- return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
321
- t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
322
- }
323
-
324
  static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
325
  if (a->type != b->type) {
326
  return false;
@@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
336
  return true;
337
  }
338
 
339
- static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
340
- switch (t->op) {
341
- case GGML_OP_PERMUTE:
342
- case GGML_OP_RESHAPE:
343
- case GGML_OP_TRANSPOSE:
344
- case GGML_OP_VIEW:
345
- return t->src[0];
346
- case GGML_OP_CPY:
347
- return t->src[1];
348
- default:
349
- return NULL;
350
- }
351
- }
352
-
353
- static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
354
- struct ggml_tensor * parent = t;
355
- do {
356
- parent = get_view_parent(parent);
357
- } while (ggml_is_view(parent));
358
- return parent;
359
- }
360
-
361
  static bool ggml_op_can_inplace(enum ggml_op op) {
362
  switch (op) {
363
  case GGML_OP_SCALE:
@@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
365
  case GGML_OP_DIAG_MASK_INF:
366
  case GGML_OP_ADD:
367
  case GGML_OP_ADD1:
368
- case GGML_OP_ACC:
369
  case GGML_OP_SUB:
370
  case GGML_OP_MUL:
371
  case GGML_OP_DIV:
@@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
375
  case GGML_OP_UNARY:
376
  case GGML_OP_ROPE:
377
  case GGML_OP_RMS_NORM:
378
- case GGML_OP_SET:
379
  case GGML_OP_SOFT_MAX:
380
  case GGML_OP_CONT:
381
- case GGML_OP_ADD_REL_POS:
382
  return true;
383
 
384
  default:
@@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
390
  struct hash_node * ht = alloc->hash_table;
391
  if (node->data == NULL) {
392
  if (ggml_is_view(node)) {
393
- size_t offset;
394
- switch(node->op) {
395
- case GGML_OP_VIEW:
396
- memcpy(&offset, node->op_params, sizeof(size_t));
397
- node->data = (char *) node->src[0]->data + offset;
398
- break;
399
- case GGML_OP_PERMUTE:
400
- case GGML_OP_RESHAPE:
401
- case GGML_OP_TRANSPOSE:
402
- node->data = node->src[0]->data;
403
- break;
404
- case GGML_OP_CPY:
405
- node->data = node->src[1]->data;
406
- break;
407
- default:
408
- GGML_ASSERT(!"unknown view op");
409
- break;
410
- }
411
  } else {
412
  // see if we can reuse a parent's buffer (inplace)
413
  if (ggml_op_can_inplace(node->op)) {
@@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
418
  }
419
 
420
  // if the node's data is external, then we cannot re-use it
421
- if ((char *) parent->data < (char *) alloc->data ||
422
- (char *) parent->data >= ((char *) alloc->data + alloc->size)) {
423
  AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
424
  continue;
425
  }
@@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
427
  struct hash_node * p_hn = hash_get(ht, parent);
428
  if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
429
  if (ggml_is_view(parent)) {
430
- struct ggml_tensor * view_src = get_view_source(parent);
431
  struct hash_node * view_src_hn = hash_get(ht, view_src);
432
  if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
433
  // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
@@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
453
  }
454
  }
455
 
456
- static size_t ggml_allocator_alloc_graph_tensors_n(
457
  struct ggml_allocr * alloc,
458
  struct ggml_cgraph ** graphs, int n_graphs,
459
  struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
@@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
469
  struct ggml_tensor * node = gf->nodes[i];
470
 
471
  if (ggml_is_view(node)) {
472
- struct ggml_tensor * view_src = get_view_source(node);
473
  hash_get(ht, view_src)->n_views += 1;
474
  }
475
 
@@ -531,11 +571,10 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
531
  AT_PRINTF("\n");
532
  }
533
 
534
-
535
  // update parents
536
  // update immediately if there is no parse_seq
537
  // update only at barriers if there is parse_seq
538
- if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
539
  int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
540
  int update_end = alloc->parse_seq_len ? ind : ind + 1;
541
  for (int i = update_start; i < update_end; i++) {
@@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
554
 
555
  if (p_hn->n_children == 0 && p_hn->n_views == 0) {
556
  if (ggml_is_view(parent)) {
557
- struct ggml_tensor * view_src = get_view_source(parent);
558
  struct hash_node * view_src_hn = hash_get(ht, view_src);
559
  view_src_hn->n_views -= 1;
560
  AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
561
  if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
562
- ggml_allocator_free_tensor(alloc, view_src);
563
  }
564
  }
565
  else {
566
  if (parent->data != node->data) {
567
- ggml_allocator_free_tensor(alloc, parent);
568
  }
569
  }
570
  }
@@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
581
  for (int i = 0; outputs[g][i] != NULL; i++) {
582
  struct ggml_tensor * output = outputs[g][i];
583
  AT_PRINTF("output: %s\n", output->name);
584
- ggml_allocator_free_tensor(alloc, output);
585
  }
586
  }
587
  }
@@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
590
  }
591
 
592
  size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
593
- return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
594
  }
 
6
  #include <stdlib.h>
7
  #include <string.h>
8
 
9
+ #ifdef __has_include
10
+ #if __has_include(<unistd.h>)
11
+ #include <unistd.h>
12
+ #if defined(_POSIX_MAPPED_FILES)
13
+ #include <sys/types.h>
14
+ #include <sys/mman.h>
15
+ #endif
16
+ #endif
17
+ #endif
18
+
19
+ #if defined(_WIN32)
20
+ #define WIN32_LEAN_AND_MEAN
21
+ #ifndef NOMINMAX
22
+ #define NOMINMAX
23
+ #endif
24
+ #include <windows.h>
25
+ #include <memoryapi.h>
26
+ #endif
27
+
28
+
29
  #define UNUSED(x) (void)(x)
30
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
31
  #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
 
119
  }
120
  #endif
121
 
122
+ static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
 
123
  return ggml_nbytes(tensor);
124
 
125
  UNUSED(alloc);
126
  }
127
 
128
+ // check if a tensor is allocated by this buffer
129
+ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
130
+ void * ptr = tensor->data;
131
+ return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
132
+ }
133
+
134
+ static bool ggml_is_view(struct ggml_tensor * t) {
135
+ return t->view_src != NULL;
136
+ }
137
+
138
  void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
139
+ #ifdef GGML_ALLOCATOR_DEBUG
140
+ GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
141
+ GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
142
+ #endif
143
+ size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
144
  size = aligned_offset(NULL, size, alloc->alignment);
145
 
146
  AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
 
164
  if (best_fit_block == -1) {
165
  // the last block is our last resort
166
  struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
167
+ max_avail = MAX(max_avail, block->size);
168
  if (block->size >= size) {
169
  best_fit_block = alloc->n_free_blocks - 1;
 
170
  } else {
171
  fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
172
  __func__, size, max_avail);
173
  GGML_ASSERT(!"not enough space in the buffer");
174
+ return;
175
  }
176
  }
177
  struct free_block * block = &alloc->free_blocks[best_fit_block];
 
206
  }
207
 
208
  // this is a very naive implementation, but for our case the number of free blocks should be very small
209
+ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
210
  void * ptr = tensor->data;
211
 
212
+ if (ggml_allocr_is_own(alloc, tensor) == false) {
213
  // the tensor was not allocated in this buffer
214
  // this can happen because the graph allocator will try to free weights and other tensors from different buffers
215
  // the easiest way to deal with this is just to ignore it
216
  return;
217
  }
218
 
219
+ size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
220
  size = aligned_offset(NULL, size, alloc->alignment);
221
  AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
222
 
 
310
  return alloc;
311
  }
312
 
313
+ // OS specific functions to allocate and free uncommitted virtual memory
314
+ static void * alloc_vmem(size_t size) {
315
+ #if defined(_WIN32)
316
+ return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
317
+ #elif defined(_POSIX_MAPPED_FILES)
318
+ void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
319
+ if (ptr == MAP_FAILED) {
320
+ return NULL;
321
+ }
322
+ return ptr;
323
+ #else
324
+ // use a fixed address for other platforms
325
+ uintptr_t base_addr = (uintptr_t)-size - 0x100;
326
+ return (void *)base_addr;
327
+ #endif
328
+ }
329
+
330
+ static void free_vmem(void * base_addr, size_t size) {
331
+ #if defined(_WIN32)
332
+ VirtualFree(base_addr, 0, MEM_RELEASE);
333
+ UNUSED(size);
334
+ #elif defined(_POSIX_MAPPED_FILES)
335
+ munmap(base_addr, size);
336
+ #else
337
+ // nothing to do
338
+ UNUSED(base_addr);
339
+ UNUSED(size);
340
+ #endif
341
+ }
342
+
343
+ // allocate uncommitted virtual memory to measure the size of the graph
344
+ static void alloc_measure_vmem(void ** base_addr, size_t * size) {
345
+ // 128GB for 64-bit, 1GB for 32-bit
346
+ *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
347
+ do {
348
+ *base_addr = alloc_vmem(*size);
349
+ if (*base_addr != NULL) {
350
+ AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
351
+ return;
352
+ }
353
+ // try again with half the size
354
+ *size /= 2;
355
+ } while (*size > 0);
356
+
357
+ GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
358
+ }
359
+
360
+ static void free_measure_vmem(void * base_addr, size_t size) {
361
+ free_vmem(base_addr, size);
362
+ }
363
 
364
  struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
365
  struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
366
 
367
+ void * base_addr;
368
+ size_t size;
369
+
370
+ alloc_measure_vmem(&base_addr, &size);
371
+
372
  *alloc = (struct ggml_allocr){
373
+ /*.data = */ base_addr,
374
+ /*.size = */ size,
375
  /*.alignment = */ alignment,
376
  /*.n_free_blocks = */ 0,
377
  /*.free_blocks = */ {{0}},
 
391
  }
392
 
393
  void ggml_allocr_free(struct ggml_allocr * alloc) {
394
+ if (alloc->measure) {
395
+ free_measure_vmem(alloc->data, alloc->size);
396
+ }
397
  free(alloc);
398
  }
399
 
 
403
 
404
  //////////// compute graph allocator
405
 
 
 
 
 
 
406
  static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
407
  if (a->type != b->type) {
408
  return false;
 
418
  return true;
419
  }
420
 
421
  static bool ggml_op_can_inplace(enum ggml_op op) {
422
  switch (op) {
423
  case GGML_OP_SCALE:
 
425
  case GGML_OP_DIAG_MASK_INF:
426
  case GGML_OP_ADD:
427
  case GGML_OP_ADD1:
 
428
  case GGML_OP_SUB:
429
  case GGML_OP_MUL:
430
  case GGML_OP_DIV:
 
434
  case GGML_OP_UNARY:
435
  case GGML_OP_ROPE:
436
  case GGML_OP_RMS_NORM:
 
437
  case GGML_OP_SOFT_MAX:
438
  case GGML_OP_CONT:
 
439
  return true;
440
 
441
  default:
 
447
  struct hash_node * ht = alloc->hash_table;
448
  if (node->data == NULL) {
449
  if (ggml_is_view(node)) {
450
+ assert(node->view_src->data != NULL);
451
+ node->data = (char *)node->view_src->data + node->view_offs;
452
  } else {
453
  // see if we can reuse a parent's buffer (inplace)
454
  if (ggml_op_can_inplace(node->op)) {
 
459
  }
460
 
461
  // if the node's data is external, then we cannot re-use it
462
+ if (ggml_allocr_is_own(alloc, parent) == false) {
 
463
  AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
464
  continue;
465
  }
 
467
  struct hash_node * p_hn = hash_get(ht, parent);
468
  if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
469
  if (ggml_is_view(parent)) {
470
+ struct ggml_tensor * view_src = parent->view_src;
471
  struct hash_node * view_src_hn = hash_get(ht, view_src);
472
  if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
473
  // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
 
493
  }
494
  }
495
 
496
+ static size_t ggml_allocr_alloc_graph_tensors_n(
497
  struct ggml_allocr * alloc,
498
  struct ggml_cgraph ** graphs, int n_graphs,
499
  struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
 
509
  struct ggml_tensor * node = gf->nodes[i];
510
 
511
  if (ggml_is_view(node)) {
512
+ struct ggml_tensor * view_src = node->view_src;
513
  hash_get(ht, view_src)->n_views += 1;
514
  }
515
 
 
571
  AT_PRINTF("\n");
572
  }
573
 
 
574
  // update parents
575
  // update immediately if there is no parse_seq
576
  // update only at barriers if there is parse_seq
577
+ if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
578
  int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
579
  int update_end = alloc->parse_seq_len ? ind : ind + 1;
580
  for (int i = update_start; i < update_end; i++) {
 
593
 
594
  if (p_hn->n_children == 0 && p_hn->n_views == 0) {
595
  if (ggml_is_view(parent)) {
596
+ struct ggml_tensor * view_src = parent->view_src;
597
  struct hash_node * view_src_hn = hash_get(ht, view_src);
598
  view_src_hn->n_views -= 1;
599
  AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
600
  if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
601
+ ggml_allocr_free_tensor(alloc, view_src);
602
  }
603
  }
604
  else {
605
  if (parent->data != node->data) {
606
+ ggml_allocr_free_tensor(alloc, parent);
607
  }
608
  }
609
  }
 
620
  for (int i = 0; outputs[g][i] != NULL; i++) {
621
  struct ggml_tensor * output = outputs[g][i];
622
  AT_PRINTF("output: %s\n", output->name);
623
+ ggml_allocr_free_tensor(alloc, output);
624
  }
625
  }
626
  }
 
629
  }
630
 
631
  size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
632
+ return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
633
  }
ggml-metal.m CHANGED
@@ -63,7 +63,10 @@ struct ggml_metal_context {
63
  GGML_METAL_DECL_KERNEL(relu);
64
  GGML_METAL_DECL_KERNEL(gelu);
65
  GGML_METAL_DECL_KERNEL(soft_max);
 
66
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
 
 
67
  GGML_METAL_DECL_KERNEL(get_rows_f16);
68
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
69
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +80,7 @@ struct ggml_metal_context {
77
  GGML_METAL_DECL_KERNEL(norm);
78
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
79
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
 
80
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
81
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
82
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -117,14 +121,17 @@ static NSString * const msl_library_source = @"see metal.metal";
117
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
118
  metal_printf("%s: allocating\n", __func__);
119
 
120
- // Show all the Metal device instances in the system
121
- NSArray * devices = MTLCopyAllDevices();
122
  id <MTLDevice> device;
123
  NSString * s;
 
 
 
 
124
  for (device in devices) {
125
  s = [device name];
126
  metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
127
  }
 
128
 
129
  // Pick and show default Metal device
130
  device = MTLCreateSystemDefaultDevice();
@@ -139,14 +146,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
139
  ctx->n_buffers = 0;
140
  ctx->concur_list_len = 0;
141
 
142
- ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
143
 
144
- #if 0
145
- // compile from source string and show compile log
146
  {
147
  NSError * error = nil;
148
 
149
- ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
 
 
 
 
 
 
 
 
150
  if (error) {
151
  metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
152
  return NULL;
@@ -161,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
161
 
162
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
163
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
164
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
165
  metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
166
 
167
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -207,7 +222,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
207
  GGML_METAL_ADD_KERNEL(relu);
208
  GGML_METAL_ADD_KERNEL(gelu);
209
  GGML_METAL_ADD_KERNEL(soft_max);
 
210
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
 
 
211
  GGML_METAL_ADD_KERNEL(get_rows_f16);
212
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
213
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -221,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
221
  GGML_METAL_ADD_KERNEL(norm);
222
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
223
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
 
224
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
225
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
226
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -247,13 +266,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
247
  #undef GGML_METAL_ADD_KERNEL
248
  }
249
 
250
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
251
  metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 
 
252
  if (ctx->device.maxTransferRate != 0) {
253
  metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
254
  } else {
255
  metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
256
  }
 
257
 
258
  return ctx;
259
  }
@@ -273,7 +294,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
273
  GGML_METAL_DEL_KERNEL(relu);
274
  GGML_METAL_DEL_KERNEL(gelu);
275
  GGML_METAL_DEL_KERNEL(soft_max);
 
276
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
 
 
277
  GGML_METAL_DEL_KERNEL(get_rows_f16);
278
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
279
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -287,6 +311,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
287
  GGML_METAL_DEL_KERNEL(norm);
288
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
289
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
 
290
  GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
291
  GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
292
  GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -365,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
365
  for (int i = 0; i < ctx->n_buffers; ++i) {
366
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
367
 
 
368
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
369
  *offs = (size_t) ioffs;
370
 
@@ -454,6 +480,7 @@ bool ggml_metal_add_buffer(
454
  }
455
  }
456
 
 
457
  metal_printf(", (%8.2f / %8.2f)",
458
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
459
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +490,9 @@ bool ggml_metal_add_buffer(
463
  } else {
464
  metal_printf("\n");
465
  }
 
 
 
466
  }
467
 
468
  return true;
@@ -698,6 +728,7 @@ void ggml_metal_graph_compute(
698
  case GGML_OP_ADD:
699
  {
700
  GGML_ASSERT(ggml_is_contiguous(src0));
 
701
 
702
  // utilize float4
703
  GGML_ASSERT(ne00 % 4 == 0);
@@ -705,6 +736,7 @@ void ggml_metal_graph_compute(
705
 
706
  if (ggml_nelements(src1) == ne10) {
707
  // src1 is a row
 
708
  [encoder setComputePipelineState:ctx->pipeline_add_row];
709
  } else {
710
  [encoder setComputePipelineState:ctx->pipeline_add];
@@ -721,6 +753,7 @@ void ggml_metal_graph_compute(
721
  case GGML_OP_MUL:
722
  {
723
  GGML_ASSERT(ggml_is_contiguous(src0));
 
724
 
725
  // utilize float4
726
  GGML_ASSERT(ne00 % 4 == 0);
@@ -728,6 +761,7 @@ void ggml_metal_graph_compute(
728
 
729
  if (ggml_nelements(src1) == ne10) {
730
  // src1 is a row
 
731
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
732
  } else {
733
  [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -743,6 +777,8 @@ void ggml_metal_graph_compute(
743
  } break;
744
  case GGML_OP_SCALE:
745
  {
 
 
746
  const float scale = *(const float *) src1->data;
747
 
748
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -750,7 +786,7 @@ void ggml_metal_graph_compute(
750
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
751
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
752
 
753
- const int64_t n = ggml_nelements(dst);
754
 
755
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
756
  } break;
@@ -762,7 +798,7 @@ void ggml_metal_graph_compute(
762
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
763
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
764
 
765
- const int64_t n = ggml_nelements(dst);
766
 
767
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
768
  } break;
@@ -782,7 +818,7 @@ void ggml_metal_graph_compute(
782
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
783
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
784
 
785
- const int64_t n = ggml_nelements(dst);
786
 
787
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
788
  } break;
@@ -796,13 +832,16 @@ void ggml_metal_graph_compute(
796
  {
797
  const int nth = 32;
798
 
799
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
 
 
 
 
800
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
801
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
802
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
803
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
804
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
805
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
806
 
807
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
808
  } break;
@@ -810,14 +849,23 @@ void ggml_metal_graph_compute(
810
  {
811
  const int n_past = ((int32_t *)(dst->op_params))[0];
812
 
813
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
 
 
 
 
814
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
815
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
816
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
817
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
818
  [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
819
 
820
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
 
 
821
  } break;
822
  case GGML_OP_MUL_MAT:
823
  {
@@ -830,8 +878,8 @@ void ggml_metal_graph_compute(
830
 
831
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
832
  // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
833
- if (ggml_is_contiguous(src0) &&
834
- ggml_is_contiguous(src1) &&
835
  src1t == GGML_TYPE_F32 &&
836
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
837
  ne00%32 == 0 &&
@@ -856,14 +904,18 @@ void ggml_metal_graph_compute(
856
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
857
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
858
  [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
859
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
860
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
861
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
 
 
 
862
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
863
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
864
  } else {
865
  int nth0 = 32;
866
  int nth1 = 1;
 
867
 
868
  // use custom matrix x vector kernel
869
  switch (src0t) {
@@ -873,8 +925,14 @@ void ggml_metal_graph_compute(
873
  nth1 = 1;
874
  if (ne11 * ne12 < 4) {
875
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
 
 
 
 
 
876
  } else {
877
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
 
878
  }
879
  } break;
880
  case GGML_TYPE_Q4_0:
@@ -995,7 +1053,7 @@ void ggml_metal_graph_compute(
995
  else if (src0t == GGML_TYPE_Q6_K) {
996
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
997
  } else {
998
- int64_t ny = (ne11 + 3)/4;
999
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1000
  }
1001
  }
@@ -1003,6 +1061,7 @@ void ggml_metal_graph_compute(
1003
  case GGML_OP_GET_ROWS:
1004
  {
1005
  switch (src0->type) {
 
1006
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1007
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1008
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1018,9 +1077,9 @@ void ggml_metal_graph_compute(
1018
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1019
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1020
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1021
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
1022
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
1023
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
1024
 
1025
  const int64_t n = ggml_nelements(src1);
1026
 
 
63
  GGML_METAL_DECL_KERNEL(relu);
64
  GGML_METAL_DECL_KERNEL(gelu);
65
  GGML_METAL_DECL_KERNEL(soft_max);
66
+ GGML_METAL_DECL_KERNEL(soft_max_4);
67
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
68
+ GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
69
+ GGML_METAL_DECL_KERNEL(get_rows_f32);
70
  GGML_METAL_DECL_KERNEL(get_rows_f16);
71
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
72
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
 
80
  GGML_METAL_DECL_KERNEL(norm);
81
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
82
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
83
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
84
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
85
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
86
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
 
121
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
122
  metal_printf("%s: allocating\n", __func__);
123
 
 
 
124
  id <MTLDevice> device;
125
  NSString * s;
126
+
127
+ #if TARGET_OS_OSX
128
+ // Show all the Metal device instances in the system
129
+ NSArray * devices = MTLCopyAllDevices();
130
  for (device in devices) {
131
  s = [device name];
132
  metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
133
  }
134
+ #endif
135
 
136
  // Pick and show default Metal device
137
  device = MTLCreateSystemDefaultDevice();
 
146
  ctx->n_buffers = 0;
147
  ctx->concur_list_len = 0;
148
 
149
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
150
 
151
+ #ifdef GGML_SWIFT
152
+ // load the default.metallib file
153
  {
154
  NSError * error = nil;
155
 
156
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
157
+ NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
158
+ NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
159
+ NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
160
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
161
+
162
+ // Load the metallib file into a Metal library
163
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
164
+
165
  if (error) {
166
  metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
167
  return NULL;
 
176
 
177
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
178
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
179
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
180
  metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
181
 
182
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
 
222
  GGML_METAL_ADD_KERNEL(relu);
223
  GGML_METAL_ADD_KERNEL(gelu);
224
  GGML_METAL_ADD_KERNEL(soft_max);
225
+ GGML_METAL_ADD_KERNEL(soft_max_4);
226
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
227
+ GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
228
+ GGML_METAL_ADD_KERNEL(get_rows_f32);
229
  GGML_METAL_ADD_KERNEL(get_rows_f16);
230
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
231
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
 
239
  GGML_METAL_ADD_KERNEL(norm);
240
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
241
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
242
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
243
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
244
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
245
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
 
266
  #undef GGML_METAL_ADD_KERNEL
267
  }
268
 
 
269
  metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
270
+ #if TARGET_OS_OSX
271
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
272
  if (ctx->device.maxTransferRate != 0) {
273
  metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
274
  } else {
275
  metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
276
  }
277
+ #endif
278
 
279
  return ctx;
280
  }
 
294
  GGML_METAL_DEL_KERNEL(relu);
295
  GGML_METAL_DEL_KERNEL(gelu);
296
  GGML_METAL_DEL_KERNEL(soft_max);
297
+ GGML_METAL_DEL_KERNEL(soft_max_4);
298
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
299
+ GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
300
+ GGML_METAL_DEL_KERNEL(get_rows_f32);
301
  GGML_METAL_DEL_KERNEL(get_rows_f16);
302
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
303
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
 
311
  GGML_METAL_DEL_KERNEL(norm);
312
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
313
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
314
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
315
  GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
316
  GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
317
  GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
 
390
  for (int i = 0; i < ctx->n_buffers; ++i) {
391
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
392
 
393
+ //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
394
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
395
  *offs = (size_t) ioffs;
396
 
 
480
  }
481
  }
482
 
483
+ #if TARGET_OS_OSX
484
  metal_printf(", (%8.2f / %8.2f)",
485
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
486
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
490
  } else {
491
  metal_printf("\n");
492
  }
493
+ #else
494
+ metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
495
+ #endif
496
  }
497
 
498
  return true;
 
728
  case GGML_OP_ADD:
729
  {
730
  GGML_ASSERT(ggml_is_contiguous(src0));
731
+ GGML_ASSERT(ggml_is_contiguous(src1));
732
 
733
  // utilize float4
734
  GGML_ASSERT(ne00 % 4 == 0);
 
736
 
737
  if (ggml_nelements(src1) == ne10) {
738
  // src1 is a row
739
+ GGML_ASSERT(ne11 == 1);
740
  [encoder setComputePipelineState:ctx->pipeline_add_row];
741
  } else {
742
  [encoder setComputePipelineState:ctx->pipeline_add];
 
753
  case GGML_OP_MUL:
754
  {
755
  GGML_ASSERT(ggml_is_contiguous(src0));
756
+ GGML_ASSERT(ggml_is_contiguous(src1));
757
 
758
  // utilize float4
759
  GGML_ASSERT(ne00 % 4 == 0);
 
761
 
762
  if (ggml_nelements(src1) == ne10) {
763
  // src1 is a row
764
+ GGML_ASSERT(ne11 == 1);
765
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
766
  } else {
767
  [encoder setComputePipelineState:ctx->pipeline_mul];
 
777
  } break;
778
  case GGML_OP_SCALE:
779
  {
780
+ GGML_ASSERT(ggml_is_contiguous(src0));
781
+
782
  const float scale = *(const float *) src1->data;
783
 
784
  [encoder setComputePipelineState:ctx->pipeline_scale];
 
786
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
787
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
788
 
789
+ const int64_t n = ggml_nelements(dst)/4;
790
 
791
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
792
  } break;
 
798
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
799
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
800
 
801
+ const int64_t n = ggml_nelements(dst)/4;
802
 
803
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
804
  } break;
 
818
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
819
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
820
 
821
+ const int64_t n = ggml_nelements(dst)/4;
822
 
823
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
824
  } break;
 
832
  {
833
  const int nth = 32;
834
 
835
+ if (ne00%4 == 0) {
836
+ [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
837
+ } else {
838
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
839
+ }
840
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
841
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
842
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
843
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
844
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
845
 
846
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
847
  } break;
 
849
  {
850
  const int n_past = ((int32_t *)(dst->op_params))[0];
851
 
852
+ if (ne00%8 == 0) {
853
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
854
+ } else {
855
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
856
+ }
857
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
858
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
859
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
860
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
861
  [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
862
 
863
+ if (ne00%8 == 0) {
864
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
865
+ }
866
+ else {
867
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
868
+ }
869
  } break;
870
  case GGML_OP_MUL_MAT:
871
  {
 
878
 
879
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
880
  // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
881
+ if (!ggml_is_transposed(src0) &&
882
+ !ggml_is_transposed(src1) &&
883
  src1t == GGML_TYPE_F32 &&
884
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
885
  ne00%32 == 0 &&
 
904
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
905
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
906
  [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
907
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
908
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
909
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
910
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
911
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
912
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
913
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
914
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
915
  } else {
916
  int nth0 = 32;
917
  int nth1 = 1;
918
+ int nrows = 1;
919
 
920
  // use custom matrix x vector kernel
921
  switch (src0t) {
 
925
  nth1 = 1;
926
  if (ne11 * ne12 < 4) {
927
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
928
+ //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
929
+ } else if (false) {
930
+ // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
931
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
932
+ nrows = ne11;
933
  } else {
934
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
935
+ nrows = 4;
936
  }
937
  } break;
938
  case GGML_TYPE_Q4_0:
 
1053
  else if (src0t == GGML_TYPE_Q6_K) {
1054
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1055
  } else {
1056
+ int64_t ny = (ne11 + nrows - 1)/nrows;
1057
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1058
  }
1059
  }
 
1061
  case GGML_OP_GET_ROWS:
1062
  {
1063
  switch (src0->type) {
1064
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
1065
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1066
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1067
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
 
1077
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1078
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1079
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1080
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1081
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1082
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1083
 
1084
  const int64_t n = ggml_nelements(src1);
1085
 
ggml-metal.metal CHANGED
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
38
  device const float4 * src0,
39
  device const float4 * src1,
40
  device float4 * dst,
41
- constant int64_t & nb,
42
  uint tpig[[thread_position_in_grid]]) {
43
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
  }
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
63
  }
64
 
65
  kernel void kernel_scale(
66
- device const float * src0,
67
- device float * dst,
68
  constant float & scale,
69
  uint tpig[[thread_position_in_grid]]) {
70
  dst[tpig] = src0[tpig] * scale;
71
  }
72
 
73
  kernel void kernel_silu(
74
- device const float * src0,
75
- device float * dst,
76
  uint tpig[[thread_position_in_grid]]) {
77
- float x = src0[tpig];
78
  dst[tpig] = x / (1.0f + exp(-x));
79
  }
80
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
89
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
 
91
  kernel void kernel_gelu(
92
- device const float * src0,
93
- device float * dst,
94
  uint tpig[[thread_position_in_grid]]) {
95
- float x = src0[tpig];
96
 
97
  // BEWARE !!!
98
  // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
107
  constant int64_t & ne00,
108
  constant int64_t & ne01,
109
  constant int64_t & ne02,
110
- threadgroup float * buf [[threadgroup(0)]],
111
  uint3 tgpig[[threadgroup_position_in_grid]],
112
  uint3 tpitg[[thread_position_in_threadgroup]],
113
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
119
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
120
 
121
  // parallel max
122
- buf[tpitg[0]] = -INFINITY;
123
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
124
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
125
- }
126
-
127
- // reduce
128
- threadgroup_barrier(mem_flags::mem_threadgroup);
129
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
130
- if (tpitg[0] < i) {
131
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
132
- }
133
- threadgroup_barrier(mem_flags::mem_threadgroup);
134
  }
135
-
136
- //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
137
- // the loop, and when that is done, buf[0] has the correct (synchronized) value
138
- //if (tpitg[0] == 0) {
139
- // buf[0] = buf[0];
140
- //}
141
-
142
- //threadgroup_barrier(mem_flags::mem_threadgroup);
143
-
144
- const float max = buf[0];
145
 
146
  // parallel sum
147
- buf[tpitg[0]] = 0.0f;
148
  for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
149
  const float exp_psrc0 = exp(psrc0[i00] - max);
150
- buf[tpitg[0]] += exp_psrc0;
151
  // Remember the result of exp here. exp is expensive, so we really do not
152
  // wish to compute it twice.
153
  pdst[i00] = exp_psrc0;
154
  }
155
 
156
- // reduce
157
- threadgroup_barrier(mem_flags::mem_threadgroup);
158
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
159
- if (tpitg[0] < i) {
160
- buf[tpitg[0]] += buf[tpitg[0] + i];
161
- }
162
- threadgroup_barrier(mem_flags::mem_threadgroup);
163
  }
 
164
 
165
- // broadcast - not needed, see above
166
- //// broadcast
167
- //if (tpitg[0] == 0) {
168
- // buf[0] = buf[0];
169
- //}
170
 
171
- //threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
 
 
 
172
 
173
- const float sum = buf[0];
174
 
175
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
176
- pdst[i00] /= sum;
177
  }
178
  }
179
 
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
192
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
193
  } else {
194
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
 
195
  }
196
  }
197
 
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
616
  }
617
  }
618
 
619
  kernel void kernel_alibi_f32(
620
  device const float * src0,
621
  device float * dst,
@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
1123
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1124
  device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1125
 
1126
- float yl[16];
1127
 
1128
- const uint16_t kmask1 = 0x0303;
1129
  const uint16_t kmask2 = 0x0f0f;
1130
 
1131
- const int tid = tiisg/2;
1132
- const int ix = tiisg%2;
1133
- const int ip = tid/8; // 0 or 1
1134
- const int il = tid/2 - 4*ip; // 0...3
1135
  const int ir = tid%2;
1136
  const int n = 8;
1137
  const int l0 = n*ir;
1138
 
1139
- const uint16_t m1 = 1 << (4*ip + il);
1140
- const uint16_t m2 = m1 << 8;
 
1142
  const int shift = 2*il;
1143
- const uint16_t qm1 = 0x0003 << shift;
1144
- const uint16_t qm2 = 0x0300 << shift;
1145
- const int32_t v1 = 4 << shift;
1146
- const int32_t v2 = 1024 << shift;
1147
 
1148
  const uint16_t s_shift1 = 4*ip;
1149
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1150
- const int ik = 4 + (il%2);
1151
 
1152
  const int q_offset = 32*ip + l0;
1153
  const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
1156
 
1157
  device const float * y1 = yy + ix*QK_K + y_offset;
1158
 
1159
- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1160
- for (int i = ix; i < nb; i += 2) {
 
 
 
 
 
1161
 
1162
  for (int l = 0; l < 8; ++l) {
1163
- yl[l+0] = y1[l+ 0];
1164
- yl[l+8] = y1[l+16];
 
 
1165
  }
1166
 
1167
  device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
1172
  for (int row = 0; row < 2; ++row) {
1173
 
1174
  const float d_all = (float)dh[0];
1175
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1176
 
1177
- float s1 = 0, s2 = 0;
 
 
 
 
 
 
 
1178
  for (int l = 0; l < n; l += 2) {
1179
- const uint16_t qs = q[l/2];
1180
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1181
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
 
 
 
 
1182
  }
1183
- float d = d_all * (s1 + 1.f/256.f * s2);
1184
- sumf1[row] += d * scales[0];
1185
- sumf2[row] += d;
 
1186
 
1187
- s1 = s2 = 0;
1188
  for (int l = 0; l < n; l += 2) {
1189
- const uint16_t qs = q[l/2+8];
1190
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1191
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
 
 
 
 
1192
  }
1193
- d = d_all * (s1 + 1.f/256.f * s2);
1194
- sumf1[row] += d * scales[1];
1195
- sumf2[row] += d;
 
1196
 
1197
  q += step;
1198
  h += step;
@@ -1201,15 +1308,17 @@ kernel void kernel_mul_mat_q3_K_f32(
1201
 
1202
  }
1203
 
1204
- y1 += 2 * QK_K;
1205
 
1206
  }
1207
 
1208
  for (int row = 0; row < 2; ++row) {
1209
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1210
- const float tot = simd_sum(sumf);
1211
- if (tiisg == 0) {
1212
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
 
 
1213
  }
1214
  }
1215
  }
@@ -1564,17 +1673,25 @@ kernel void kernel_mul_mat_q5_K_f32(
1564
  sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1565
  sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1566
 
1567
- float4 acc = {0.f, 0.f, 0.f, 0.f};
 
1568
  for (int l = 0; l < n; ++l) {
1569
  uint8_t h = qh[l];
1570
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1571
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1572
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1573
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
 
 
 
 
1574
  }
1575
  const float dall = dh[0];
1576
  const float dmin = dh[1];
1577
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
 
 
 
1578
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1579
 
1580
  q1 += step;
@@ -1747,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
1747
 
1748
  //============================= templates and their specializations =============================
1749
 
 
 
 
 
 
 
 
 
 
1750
  template <typename type4x4>
1751
  void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1752
  half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1884,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
1758
  template <typename type4x4>
1759
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1760
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1761
- const half d = il ? (xb->d / 16.h) : xb->d;
1762
- const half m = il ? ( -8.h * 16.h) : -8.h;
 
1763
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1764
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1765
 
1766
  for (int i=0;i<8;i++) {
1767
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
1768
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1769
  }
1770
  }
1771
 
1772
  template <typename type4x4>
1773
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1774
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1775
- const half d = il ? (xb->d / 16.h) : xb->d;
1776
- const half m = xb->m;
 
1777
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1778
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1779
 
1780
  for (int i=0;i<8;i++) {
1781
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
1782
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1783
  }
1784
  }
1785
 
@@ -1815,7 +1943,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
1815
 
1816
  template <typename type4x4>
1817
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1818
- const float d_all = (float)(xb->d);
1819
  device const uint8_t * q = (device const uint8_t *)xb->qs;
1820
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
1821
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +1956,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1828
  ((il/4)>0 ? 12 : 3);
1829
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1830
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1831
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1832
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1833
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
 
1834
 
1835
- il = (il/2)%4;
1836
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1837
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
 
1838
 
1839
  for (int i = 0; i < 16; ++i) {
1840
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1841
  }
1842
  #else
1843
  float kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +1982,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1852
  #endif
1853
  }
1854
 
 
 
 
 
 
1855
  template <typename type4x4>
1856
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1857
- device const uint8_t * q = xb->qs;
1858
 
1859
  #if QK_K == 256
1860
- const float d = (float)(xb->d);
1861
- const float min = (float)(xb->dmin);
1862
  short is = (il/4) * 2;
1863
  q = q + (il/4) * 32 + 16 * (il&1);
1864
- il = il%4;
1865
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1866
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1867
- const float ml = il<2 ? min * sc[1] : min * sc[3];
 
 
1868
  #else
1869
  q = q + 16 * (il&1);
1870
  device const uint8_t * s = xb->scales;
1871
  device const half2 * dh = (device const half2 *)xb->d;
1872
  const float2 d = (float2)dh[0];
1873
  const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1874
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
1875
  #endif
1876
  const ushort mask = il<2 ? 0x0F : 0xF0;
1877
  for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2020,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1885
  device const uint8_t * qh = xb->qh;
1886
 
1887
  #if QK_K == 256
1888
- const float d = (float)(xb->d);
1889
- const float min = (float)(xb->dmin);
1890
  short is = (il/4) * 2;
1891
  q = q + 32 * (il/4) + 16 * (il&1);
1892
  qh = qh + 16 * (il&1);
1893
  uint8_t ul = 1 << (il/2);
1894
- il = il%4;
1895
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1896
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1897
- const float ml = il<2 ? min * sc[1] : min * sc[3];
 
 
1898
 
1899
- const ushort mask = il<2 ? 0x0F : 0xF0;
1900
- const float qh_val = il<2 ? 16.f : 256.f;
1901
  for (int i = 0; i < 16; ++i) {
1902
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1903
  }
@@ -1916,7 +2051,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1916
 
1917
  template <typename type4x4>
1918
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1919
- const float d_all = (float)(xb->d);
1920
  device const uint8_t * ql = (device const uint8_t *)xb->ql;
1921
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
1922
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2059,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
1924
  #if QK_K == 256
1925
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1926
  qh = qh + 32*(il/8) + 16*(il&1);
1927
- float sc = scales[(il%2) + 2 * ((il/2))];
1928
- il = (il/2)%4;
1929
  #else
1930
  ql = ql + 16 * (il&1);
1931
- float sc = scales[il];
1932
  #endif
 
 
 
 
 
1933
  for (int i = 0; i < 16; ++i) {
1934
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1935
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1936
- const float coef = il>1 ? 1.f/16.f : 1.f;
1937
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1938
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1939
- reg[i/4][i%4] = d_all * sc * q * coef;
1940
  }
1941
  }
1942
 
@@ -1976,22 +2113,25 @@ kernel void kernel_get_rows(
1976
  // each block_q contains 16*nl weights
1977
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
1978
  kernel void kernel_mul_mm(device const uchar * src0,
1979
- device const float * src1,
1980
- device float * dst,
1981
- constant int64_t & ne00,
1982
- constant int64_t & ne02,
1983
- constant int64_t & nb01,
1984
- constant int64_t & nb02,
1985
- constant int64_t & ne12,
1986
- constant int64_t & ne0,
1987
- constant int64_t & ne1,
1988
- constant uint & gqa,
1989
- threadgroup uchar * shared_memory [[threadgroup(0)]],
1990
- uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiitg[[thread_index_in_threadgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
-
1994
- threadgroup half * sa = ((threadgroup half *)shared_memory);
 
 
 
1995
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
1996
 
1997
  const uint r0 = tgpig.y;
@@ -2004,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2004
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2005
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
2006
 
2007
- simdgroup_half8x8 ma[4];
2008
  simdgroup_float8x8 mb[2];
2009
  simdgroup_float8x8 c_res[8];
2010
  for (int i = 0; i < 8; i++){
@@ -2012,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
2012
  }
2013
 
2014
  short il = (tiitg % THREAD_PER_ROW);
2015
- uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
2016
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
2017
- device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
2018
- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
 
 
 
 
 
2019
 
2020
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2021
  //load data and store to threadgroup memory
@@ -2095,6 +2240,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2095
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2096
  constant uint64_t &, constant uint64_t &, uint, uint, uint);
2097
 
 
2098
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2099
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2100
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2105,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
2105
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2106
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2107
 
2108
- typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
2109
- constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
2110
- constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
2111
-
2112
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2113
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2114
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2115
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2116
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2117
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2118
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
 
38
  device const float4 * src0,
39
  device const float4 * src1,
40
  device float4 * dst,
41
+ constant int64_t & nb,
42
  uint tpig[[thread_position_in_grid]]) {
43
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
  }
 
63
  }
64
 
65
  kernel void kernel_scale(
66
+ device const float4 * src0,
67
+ device float4 * dst,
68
  constant float & scale,
69
  uint tpig[[thread_position_in_grid]]) {
70
  dst[tpig] = src0[tpig] * scale;
71
  }
72
 
73
  kernel void kernel_silu(
74
+ device const float4 * src0,
75
+ device float4 * dst,
76
  uint tpig[[thread_position_in_grid]]) {
77
+ device const float4 & x = src0[tpig];
78
  dst[tpig] = x / (1.0f + exp(-x));
79
  }
80
 
 
89
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
 
91
  kernel void kernel_gelu(
92
+ device const float4 * src0,
93
+ device float4 * dst,
94
  uint tpig[[thread_position_in_grid]]) {
95
+ device const float4 & x = src0[tpig];
96
 
97
  // BEWARE !!!
98
  // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
 
107
  constant int64_t & ne00,
108
  constant int64_t & ne01,
109
  constant int64_t & ne02,
 
110
  uint3 tgpig[[threadgroup_position_in_grid]],
111
  uint3 tpitg[[thread_position_in_threadgroup]],
112
  uint3 ntg[[threads_per_threadgroup]]) {
 
118
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
119
 
120
  // parallel max
121
+ float lmax = psrc0[tpitg[0]];
122
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
123
+ lmax = MAX(lmax, psrc0[i00]);
 
 
 
 
 
 
 
 
 
124
  }
125
+ const float max = simd_max(lmax);
 
 
 
 
 
 
 
 
 
126
 
127
  // parallel sum
128
+ float lsum = 0.0f;
129
  for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
130
  const float exp_psrc0 = exp(psrc0[i00] - max);
131
+ lsum += exp_psrc0;
132
  // Remember the result of exp here. exp is expensive, so we really do not
133
  // wish to compute it twice.
134
  pdst[i00] = exp_psrc0;
135
  }
136
 
137
+ const float sum = simd_sum(lsum);
138
+
139
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
140
+ pdst[i00] /= sum;
141
+ }
142
+ }
143
+
144
+ kernel void kernel_soft_max_4(
145
+ device const float * src0,
146
+ device float * dst,
147
+ constant int64_t & ne00,
148
+ constant int64_t & ne01,
149
+ constant int64_t & ne02,
150
+ uint3 tgpig[[threadgroup_position_in_grid]],
151
+ uint3 tpitg[[thread_position_in_threadgroup]],
152
+ uint3 ntg[[threads_per_threadgroup]]) {
153
+ const int64_t i03 = tgpig[2];
154
+ const int64_t i02 = tgpig[1];
155
+ const int64_t i01 = tgpig[0];
156
+
157
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
158
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
159
+
160
+ // parallel max
161
+ float4 lmax4 = psrc4[tpitg[0]];
162
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
163
+ lmax4 = fmax(lmax4, psrc4[i00]);
164
  }
165
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
166
 
167
+ const float max = simd_max(lmax);
 
 
 
 
168
 
169
+ // parallel sum
170
+ float4 lsum4 = 0.0f;
171
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
172
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
173
+ lsum4 += exp_psrc4;
174
+ pdst4[i00] = exp_psrc4;
175
+ }
176
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
177
 
178
+ const float sum = simd_sum(lsum);
179
 
180
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
181
+ pdst4[i00] /= sum;
182
  }
183
  }
184
 
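Both soft_max kernels compute the usual numerically stable softmax; the rewrite replaces the old threadgroup-memory reduction with simd_max/simd_sum, so each lane scans a strided slice of the row and the per-lane partials are combined across the simdgroup. For reference, the scalar computation being parallelized is roughly:

    #include <math.h>

    // reference softmax over one row of length n (what the kernel computes per row)
    static void softmax_row(const float * x, float * y, int n) {
        float max = x[0];
        for (int i = 1; i < n; i++) max = x[i] > max ? x[i] : max;

        float sum = 0.0f;
        for (int i = 0; i < n; i++) {
            y[i] = expf(x[i] - max); // subtract the max for numerical stability
            sum += y[i];
        }

        for (int i = 0; i < n; i++) y[i] /= sum;
    }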
 
197
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
198
  } else {
199
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
200
+ }
201
+ }
202
+
203
+ kernel void kernel_diag_mask_inf_8(
204
+ device const float4 * src0,
205
+ device float4 * dst,
206
+ constant int64_t & ne00,
207
+ constant int64_t & ne01,
208
+ constant int & n_past,
209
+ uint3 tpig[[thread_position_in_grid]]) {
210
+
211
+ const int64_t i = 2*tpig[0];
212
+
213
+ dst[i+0] = src0[i+0];
214
+ dst[i+1] = src0[i+1];
215
+ int64_t i4 = 4*i;
216
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
217
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
218
+ const int64_t i00 = i4;
219
+ for (int k = 3; k >= 0; --k) {
220
+ if (i00 + 4 + k <= n_past + i01) {
221
+ break;
222
+ }
223
+ dst[i+1][k] = -INFINITY;
224
+ if (i00 + k > n_past + i01) {
225
+ dst[i][k] = -INFINITY;
226
+ }
227
  }
228
  }
229
 
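kernel_diag_mask_inf_8 handles eight floats (two float4 values) per thread, which is why the host side dispatches ne00*ne01*ne02/8 threadgroups for it. The flat index is decomposed back into (batch, row, column) so the causal condition column > n_past + row can be applied. A scalar sketch of the same arithmetic in C, for illustration only:

    #include <math.h>

    // mask element `idx` of a tensor flattened from shape (ne02, ne01, ne00)
    static void diag_mask_inf_one(float * dst, const float * src, long idx,
                                  long ne00, long ne01, int n_past) {
        long rem = idx;
        const long i02 = rem / (ne00*ne01); rem -= i02*ne00*ne01; // batch (not used by the condition)
        const long i01 = rem /  ne00;       rem -= i01*ne00;      // row
        const long i00 = rem;                                     // column
        (void) i02;

        dst[idx] = (i00 > n_past + i01) ? -INFINITY : src[idx];
    }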
 
648
  }
649
  }
650
 
651
+ // Assumes row size (ne00) is a multiple of 4
652
+ kernel void kernel_mul_mat_f16_f32_l4(
653
+ device const char * src0,
654
+ device const char * src1,
655
+ device float * dst,
656
+ constant int64_t & ne00,
657
+ constant int64_t & ne01,
658
+ constant int64_t & ne02,
659
+ constant uint64_t & nb00,
660
+ constant uint64_t & nb01,
661
+ constant uint64_t & nb02,
662
+ constant int64_t & ne10,
663
+ constant int64_t & ne11,
664
+ constant int64_t & ne12,
665
+ constant uint64_t & nb10,
666
+ constant uint64_t & nb11,
667
+ constant uint64_t & nb12,
668
+ constant int64_t & ne0,
669
+ constant int64_t & ne1,
670
+ uint3 tgpig[[threadgroup_position_in_grid]],
671
+ uint tiisg[[thread_index_in_simdgroup]]) {
672
+
673
+ const int nrows = ne11;
674
+ const int64_t r0 = tgpig.x;
675
+ const int64_t im = tgpig.z;
676
+
677
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
678
+
679
+ for (int r1 = 0; r1 < nrows; ++r1) {
680
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
681
+
682
+ float sumf = 0;
683
+ for (int i = tiisg; i < ne00/4; i += 32) {
684
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
685
+ }
686
+
687
+ float all_sum = simd_sum(sumf);
688
+ if (tiisg == 0) {
689
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
690
+ }
691
+ }
692
+ }
693
+
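The l4 kernel assigns one src0 row per threadgroup and loops over all src1 rows; within a row, the 32 lanes of the simdgroup each accumulate every 32nd 4-element chunk and the partial sums are combined with simd_sum. The dot product being computed is, in scalar form (operands shown as float for simplicity; the kernel reads src0 as half4 and src1 as float4):

    // reference dot product walked in 4-element chunks; assumes n is a multiple of 4,
    // mirroring the "row size (ne00) is a multiple of 4" requirement stated above
    static float dot_rows_4(const float * x, const float * y, long n) {
        float sum = 0.0f;
        for (long i = 0; i < n/4; i++) {
            for (int k = 0; k < 4; k++) {
                sum += x[4*i + k] * y[4*i + k];
            }
        }
        return sum;
    }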
694
  kernel void kernel_alibi_f32(
695
  device const float * src0,
696
  device float * dst,
 
1198
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1199
  device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1200
 
1201
+ float yl[32];
1202
 
1203
+ const uint16_t kmask1 = 0x3030;
1204
  const uint16_t kmask2 = 0x0f0f;
1205
 
1206
+ const int tid = tiisg/4;
1207
+ const int ix = tiisg%4;
1208
+ const int ip = tid/4; // 0 or 1
1209
+ const int il = 2*((tid%4)/2); // 0 or 2
1210
  const int ir = tid%2;
1211
  const int n = 8;
1212
  const int l0 = n*ir;
1213
 
1214
+ // One would think that the Metal compiler would figure out that ip and il can only have
1215
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
1216
+ // with these two tables.
1217
+ //
1218
+ // Possible masks for the high bit
1219
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
1220
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
1221
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
1222
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
1223
+
1224
+ // Possible masks for the low 2 bits
1225
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
1226
+
1227
+ const ushort4 hm = mm[2*ip + il/2];
1228
 
1229
  const int shift = 2*il;
1230
+ const float v1 = il == 0 ? 4.f : 64.f;
1231
+ const float v2 = 4.f * v1;
 
 
1232
 
1233
  const uint16_t s_shift1 = 4*ip;
1234
+ const uint16_t s_shift2 = s_shift1 + il;
 
1235
 
1236
  const int q_offset = 32*ip + l0;
1237
  const int y_offset = 128*ip + 32*il + l0;
 
1240
 
1241
  device const float * y1 = yy + ix*QK_K + y_offset;
1242
 
1243
+ uint32_t scales32, aux32;
1244
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
1245
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
1246
+
1247
+ float sumf1[2] = {0.f};
1248
+ float sumf2[2] = {0.f};
1249
+ for (int i = ix; i < nb; i += 4) {
1250
 
1251
  for (int l = 0; l < 8; ++l) {
1252
+ yl[l+ 0] = y1[l+ 0];
1253
+ yl[l+ 8] = y1[l+16];
1254
+ yl[l+16] = y1[l+32];
1255
+ yl[l+24] = y1[l+48];
1256
  }
1257
 
1258
  device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
 
1263
  for (int row = 0; row < 2; ++row) {
1264
 
1265
  const float d_all = (float)dh[0];
 
1266
 
1267
+ scales16[0] = a[4];
1268
+ scales16[1] = a[5];
1269
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
1270
+ scales16[0] = a[il+0];
1271
+ scales16[1] = a[il+1];
1272
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
1273
+
1274
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
1275
  for (int l = 0; l < n; l += 2) {
1276
+ const int32_t qs = q[l/2];
1277
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
1278
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
1279
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
1280
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
1281
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
1282
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
1283
  }
1284
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1285
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1286
+ sumf1[row] += d1 * (scales[0] - 32);
1287
+ sumf2[row] += d2 * (scales[2] - 32);
1288
 
1289
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
1290
  for (int l = 0; l < n; l += 2) {
1291
+ const int32_t qs = q[l/2+8];
1292
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
1293
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
1294
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
1295
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
1296
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
1297
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
1298
  }
1299
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1300
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1301
+ sumf1[row] += d1 * (scales[1] - 32);
1302
+ sumf2[row] += d2 * (scales[3] - 32);
1303
 
1304
  q += step;
1305
  h += step;
 
1308
 
1309
  }
1310
 
1311
+ y1 += 4 * QK_K;
1312
 
1313
  }
1314
 
1315
  for (int row = 0; row < 2; ++row) {
1316
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
1317
+ sumf1[row] = simd_sum(sumf);
1318
+ }
1319
+ if (tiisg == 0) {
1320
+ for (int row = 0; row < 2; ++row) {
1321
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1322
  }
1323
  }
1324
  }
 
1673
  sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1674
  sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1675
 
1676
+ float4 acc1 = {0.f};
1677
+ float4 acc2 = {0.f};
1678
  for (int l = 0; l < n; ++l) {
1679
  uint8_t h = qh[l];
1680
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
1681
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
1682
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
1683
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
1684
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
1685
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
1686
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
1687
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
1688
  }
1689
  const float dall = dh[0];
1690
  const float dmin = dh[1];
1691
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
1692
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
1693
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
1694
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
1695
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1696
 
1697
  q1 += step;
 
1864
 
1865
  //============================= templates and their specializations =============================
1866
 
1867
+ // NOTE: this is not dequantizing - we are simply fitting the template
1868
+ template <typename type4x4>
1869
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
1870
+ float4x4 temp = *(((device float4x4 *)src));
1871
+ for (int i = 0; i < 16; i++){
1872
+ reg[i/4][i%4] = temp[i/4][i%4];
1873
+ }
1874
+ }
1875
+
1876
  template <typename type4x4>
1877
  void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1878
  half4x4 temp = *(((device half4x4 *)src));
 
1884
  template <typename type4x4>
1885
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1886
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1887
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1888
+ const float d2 = d1 / 256.f;
1889
+ const float md = -8.h * xb->d;
1890
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1891
+ const ushort mask1 = mask0 << 8;
1892
 
1893
  for (int i=0;i<8;i++) {
1894
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
1895
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
1896
  }
1897
  }
1898
 
1899
  template <typename type4x4>
1900
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1901
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1902
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1903
+ const float d2 = d1 / 256.f;
1904
+ const float m = xb->m;
1905
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1906
+ const ushort mask1 = mask0 << 8;
1907
 
1908
  for (int i=0;i<8;i++) {
1909
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
1910
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
1911
  }
1912
  }
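
For orientation: a block_q4_0 packs 32 weights as one fp16 scale d plus 16 bytes of 4-bit quants, and a weight is reconstructed as d * (q - 8); block_q4_1 carries an extra fp16 offset m and reconstructs d * q + m. The templates above produce 16 values per call and fold the "- 8" into md = -8*d (with d2 = d1/256 compensating for reading the quants as uint16_t). A minimal scalar C++ reference of the same Q4_0 math - an illustration, not the ggml implementation - with the low nibbles holding weights 0..15 of the block and the high nibbles weights 16..31:

    #include <cstdint>

    // illustrative reference only (the fp16 scale is shown as a plain float)
    struct block_q4_0_ref {
        float   d;        // scale
        uint8_t qs[16];   // 32 x 4-bit quants, two per byte
    };

    // w[j]      = d * ((qs[j] & 0x0F) - 8)
    // w[j + 16] = d * ((qs[j] >>   4) - 8)
    static void dequantize_block_q4_0_ref(const block_q4_0_ref & b, float * w /* [32] */) {
        for (int j = 0; j < 16; ++j) {
            w[j +  0] = b.d * (float)((b.qs[j] & 0x0F) - 8);
            w[j + 16] = b.d * (float)((b.qs[j] >>   4) - 8);
        }
    }
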
1913
 
 
1943
 
1944
  template <typename type4x4>
1945
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1946
+ const half d_all = xb->d;
1947
  device const uint8_t * q = (device const uint8_t *)xb->qs;
1948
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
1949
  device const int8_t * scales = (device const int8_t *)xb->scales;
 
1956
  ((il/4)>0 ? 12 : 3);
1957
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1958
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1959
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
1960
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1961
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
1962
+ const half ml = 4.h * dl;
1963
 
1964
+ il = (il/2) & 3;
1965
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1966
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1967
+ dl *= coef;
1968
 
1969
  for (int i = 0; i < 16; ++i) {
1970
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
1971
  }
1972
  #else
1973
  float kcoef = il&1 ? 1.f/16.f : 1.f;
 
1982
  #endif
1983
  }
1984
 
1985
+ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
1986
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
1987
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
1988
+ }
1989
+
1990
  template <typename type4x4>
1991
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1992
+ device const uchar * q = xb->qs;
1993
 
1994
  #if QK_K == 256
 
 
1995
  short is = (il/4) * 2;
1996
  q = q + (il/4) * 32 + 16 * (il&1);
1997
+ il = il & 3;
1998
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
1999
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
2000
+ const half min = xb->dmin;
2001
+ const half dl = d * sc[0];
2002
+ const half ml = min * sc[1];
2003
  #else
2004
  q = q + 16 * (il&1);
2005
  device const uint8_t * s = xb->scales;
2006
  device const half2 * dh = (device const half2 *)xb->d;
2007
  const float2 d = (float2)dh[0];
2008
  const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
2009
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
2010
  #endif
2011
  const ushort mask = il<2 ? 0x0F : 0xF0;
2012
  for (int i = 0; i < 16; ++i) {
 
2020
  device const uint8_t * qh = xb->qh;
2021
 
2022
  #if QK_K == 256
 
 
2023
  short is = (il/4) * 2;
2024
  q = q + 32 * (il/4) + 16 * (il&1);
2025
  qh = qh + 16 * (il&1);
2026
  uint8_t ul = 1 << (il/2);
2027
+ il = il & 3;
2028
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2029
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
2030
+ const half min = xb->dmin;
2031
+ const half dl = d * sc[0];
2032
+ const half ml = min * sc[1];
2033
 
2034
+ const ushort mask = il<2 ? 0x0F : 0xF0;
2035
+ const half qh_val = il<2 ? 16.h : 256.h;
2036
  for (int i = 0; i < 16; ++i) {
2037
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
2038
  }
 
2051
 
2052
  template <typename type4x4>
2053
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
2054
+ const half d_all = xb->d;
2055
  device const uint8_t * ql = (device const uint8_t *)xb->ql;
2056
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
2057
  device const int8_t * scales = (device const int8_t *)xb->scales;
 
2059
  #if QK_K == 256
2060
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
2061
  qh = qh + 32*(il/8) + 16*(il&1);
2062
+ half sc = scales[(il%2) + 2 * ((il/2))];
2063
+ il = (il/2) & 3;
2064
  #else
2065
  ql = ql + 16 * (il&1);
2066
+ half sc = scales[il];
2067
  #endif
2068
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2069
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
2070
+ const half coef = il>1 ? 1.f/16.h : 1.h;
2071
+ const half ml = d_all * sc * 32.h;
2072
+ const half dl = d_all * sc * coef;
2073
  for (int i = 0; i < 16; ++i) {
2074
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
2075
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
2076
+ reg[i/4][i%4] = dl * q - ml;
 
 
 
2077
  }
2078
  }
2079
 
 
2113
  // each block_q contains 16*nl weights
2114
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2115
  kernel void kernel_mul_mm(device const uchar * src0,
2116
+ device const uchar * src1,
2117
+ device float * dst,
2118
+ constant int64_t & ne00,
2119
+ constant int64_t & ne02,
2120
+ constant int64_t & nb01,
2121
+ constant int64_t & nb02,
2122
+ constant int64_t & ne12,
2123
+ constant int64_t & nb10,
2124
+ constant int64_t & nb11,
2125
+ constant int64_t & nb12,
2126
+ constant int64_t & ne0,
2127
+ constant int64_t & ne1,
2128
+ constant uint & gqa,
2129
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
2130
+ uint3 tgpig[[threadgroup_position_in_grid]],
2131
+ uint tiitg[[thread_index_in_threadgroup]],
2132
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2133
+
2134
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
2135
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
2136
 
2137
  const uint r0 = tgpig.y;
 
2144
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2145
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
2146
 
2147
+ simdgroup_half8x8 ma[4];
2148
  simdgroup_float8x8 mb[2];
2149
  simdgroup_float8x8 c_res[8];
2150
  for (int i = 0; i < 8; i++){
 
2152
  }
2153
 
2154
  short il = (tiitg % THREAD_PER_ROW);
2155
+
2156
+ uint offset0 = im/gqa*nb02;
2157
+ ushort offset1 = il/nl;
2158
+
2159
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
2160
+ device const float * y = (device const float *)(src1
2161
+ + nb12 * im
2162
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
2163
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2164
 
2165
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2166
  //load data and store to threadgroup memory
 
2240
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2241
  constant uint64_t &, constant uint64_t &, uint, uint, uint);
2242
 
2243
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2244
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2245
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2246
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 
2251
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2252
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2253
 
2254
+ typedef void (mat_mm_t)(
2255
+ device const uchar * src0,
2256
+ device const uchar * src1,
2257
+ device float * dst,
2258
+ constant int64_t & ne00,
2259
+ constant int64_t & ne02,
2260
+ constant int64_t & nb01,
2261
+ constant int64_t & nb02,
2262
+ constant int64_t & ne12,
2263
+ constant int64_t & nb10,
2264
+ constant int64_t & nb11,
2265
+ constant int64_t & nb12,
2266
+ constant int64_t & ne0,
2267
+ constant int64_t & ne1,
2268
+ constant uint & gqa,
2269
+ threadgroup uchar *, uint3, uint, uint);
2270
+
2271
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2272
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2273
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2274
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2275
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2276
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2277
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
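
The [[host_name(...)]] strings above are the contract with the host side: for every supported source type there is one "kernel_mul_mm_<type>_f32" entry point that ggml-metal.m looks up when it loads the Metal library. A hypothetical C++ helper, shown only to make that naming convention explicit (the real pipeline selection lives in ggml-metal.m, not in this file):

    #include "ggml.h"

    // hypothetical helper - illustrates the kernel naming scheme only
    static const char * metal_mul_mm_kernel_name(enum ggml_type type) {
        switch (type) {
            case GGML_TYPE_F16:  return "kernel_mul_mm_f16_f32";
            case GGML_TYPE_Q4_0: return "kernel_mul_mm_q4_0_f32";
            case GGML_TYPE_Q4_1: return "kernel_mul_mm_q4_1_f32";
            case GGML_TYPE_Q8_0: return "kernel_mul_mm_q8_0_f32";
            case GGML_TYPE_Q4_K: return "kernel_mul_mm_q4_K_f32";
            // q2_K, q3_K, q5_K, q6_K follow the same pattern
            default:             return nullptr;
        }
    }
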
ggml.c CHANGED
@@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
4303
  }
4304
 
4305
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
4306
- size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
4307
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
4308
- nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
 
 
 
 
4309
  }
 
 
 
 
 
 
 
4310
  return nbytes;
4311
  }
4312
 
@@ -18345,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
18345
  for (int i = 0; i < cgraph->n_leafs; i++) {
18346
  struct ggml_tensor * node = cgraph->leafs[i];
18347
 
18348
- GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
18349
  i,
18350
  node->ne[0], node->ne[1],
18351
- ggml_op_name(node->op));
 
18352
  }
18353
 
18354
  for (int i = 0; i < GGML_OP_COUNT; i++) {
 
4303
  }
4304
 
4305
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
4306
+ size_t nbytes;
4307
+ size_t blck_size = ggml_blck_size(tensor->type);
4308
+ if (blck_size == 1) {
4309
+ nbytes = ggml_type_size(tensor->type);
4310
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
4311
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
4312
+ }
4313
  }
4314
+ else {
4315
+ nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
4316
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
4317
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
4318
+ }
4319
+ }
4320
+
4321
  return nbytes;
4322
  }
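
The split on blck_size presumably keeps the computation correct for permuted and transposed views, where nb[0] is no longer the element size. A quick worked check of the blck_size == 1 branch: a contiguous 4x3 f32 tensor has nb = {4, 16, 48, 48}, so it reports 4 + 3*4 + 2*16 = 48 bytes, i.e. 4*3*sizeof(float); for its transposed view (ne = {3, 4}, nb = {16, 4}) it still reports 4 + 2*16 + 3*4 = 48, whereas the previous expression ne[0]*nb[0]/blck_size + (ne[1]-1)*nb[1] would give 3*16 + 3*4 = 60.
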
4323
 
 
18356
  for (int i = 0; i < cgraph->n_leafs; i++) {
18357
  struct ggml_tensor * node = cgraph->leafs[i];
18358
 
18359
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
18360
  i,
18361
  node->ne[0], node->ne[1],
18362
+ ggml_op_name(node->op),
18363
+ ggml_get_name(node));
18364
  }
18365
 
18366
  for (int i = 0; i < GGML_OP_COUNT; i++) {
whisper.cpp CHANGED
@@ -3,11 +3,16 @@
3
  #include "coreml/whisper-encoder.h"
4
  #endif
5
 
 
 
 
 
6
  #ifdef WHISPER_USE_OPENVINO
7
  #include "openvino/whisper-openvino-encoder.h"
8
  #endif
9
 
10
  #include "ggml.h"
 
11
 
12
  #include <algorithm>
13
  #include <cassert>
@@ -24,6 +29,7 @@
24
  #include <vector>
25
  #include <regex>
26
  #include <random>
 
27
 
28
  #if defined(_MSC_VER)
29
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -115,9 +121,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
115
  //#define WHISPER_USE_FLASH_FF
116
  #define WHISPER_MAX_DECODERS 16
117
 
118
- #define WHISPER_USE_SCRATCH
119
- #define WHISPER_MAX_SCRATCH_BUFFERS 16
120
-
121
  //
122
  // ggml helpers
123
  //
@@ -133,6 +136,44 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
133
  ggml_graph_compute(graph, &plan);
134
  }
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  // available whisper models
137
  enum e_model {
138
  MODEL_UNKNOWN,
@@ -247,38 +288,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
247
 
248
  static const size_t MB = 1ull*1024*1024;
249
 
250
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
251
- { MODEL_TINY, 62ull*MB },
252
- { MODEL_BASE, 80ull*MB },
253
- { MODEL_SMALL, 120ull*MB },
254
- { MODEL_MEDIUM, 158ull*MB },
255
- { MODEL_LARGE, 198ull*MB },
256
- };
257
-
258
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
259
- { MODEL_TINY, 18ull*MB },
260
- { MODEL_BASE, 24ull*MB },
261
- { MODEL_SMALL, 36ull*MB },
262
- { MODEL_MEDIUM, 48ull*MB },
263
- { MODEL_LARGE, 60ull*MB },
264
- };
265
-
266
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
267
- { MODEL_TINY, 4ull*MB },
268
- { MODEL_BASE, 4ull*MB },
269
- { MODEL_SMALL, 6ull*MB },
270
- { MODEL_MEDIUM, 7ull*MB },
271
- { MODEL_LARGE, 9ull*MB },
272
- };
273
-
274
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
275
- { MODEL_TINY, 4ull*MB },
276
- { MODEL_BASE, 4ull*MB },
277
- { MODEL_SMALL, 6ull*MB },
278
- { MODEL_MEDIUM, 7ull*MB },
279
- { MODEL_LARGE, 9ull*MB },
280
- };
281
-
282
  static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
283
  { GGML_TYPE_F32,
284
  {
@@ -345,38 +355,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
345
  },
346
  };
347
 
348
- static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
349
- { MODEL_TINY, 3ull*MB },
350
- { MODEL_BASE, 6ull*MB },
351
- { MODEL_SMALL, 16ull*MB },
352
- { MODEL_MEDIUM, 43ull*MB },
353
- { MODEL_LARGE, 71ull*MB },
354
- };
355
-
356
- static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
357
- { MODEL_TINY, 9ull*MB },
358
- { MODEL_BASE, 18ull*MB },
359
- { MODEL_SMALL, 53ull*MB },
360
- { MODEL_MEDIUM, 141ull*MB },
361
- { MODEL_LARGE, 235ull*MB },
362
- };
363
-
364
- static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
365
- { MODEL_TINY, 30ull*MB },
366
- { MODEL_BASE, 38ull*MB },
367
- { MODEL_SMALL, 56ull*MB },
368
- { MODEL_MEDIUM, 74ull*MB },
369
- { MODEL_LARGE, 94ull*MB },
370
- };
371
-
372
- static const std::map<e_model, size_t> MEM_REQ_DECODE = {
373
- { MODEL_TINY, 3ull*MB },
374
- { MODEL_BASE, 5ull*MB },
375
- { MODEL_SMALL, 10ull*MB },
376
- { MODEL_MEDIUM, 18ull*MB },
377
- { MODEL_LARGE, 27ull*MB },
378
- };
379
-
380
  struct whisper_mel {
381
  int n_len;
382
  int n_len_org;
@@ -657,15 +635,57 @@ struct kv_buf {
657
  std::vector<uint8_t> v;
658
  };
659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  struct whisper_state {
661
  int64_t t_sample_us = 0;
662
  int64_t t_encode_us = 0;
663
  int64_t t_decode_us = 0;
 
664
  int64_t t_mel_us = 0;
665
 
666
  int32_t n_sample = 0; // number of tokens sampled
667
  int32_t n_encode = 0; // number of encoder calls
668
- int32_t n_decode = 0; // number of decoder calls
 
669
  int32_t n_fail_p = 0; // number of logprob threshold failures
670
  int32_t n_fail_h = 0; // number of entropy threshold failures
671
 
@@ -679,13 +699,20 @@ struct whisper_state {
679
  // buffer for swapping KV caches between decoders during beam-search
680
  std::vector<kv_buf> kv_swap_bufs;
681
 
682
- // memory buffers used by encode / decode contexts
683
- std::vector<uint8_t> buf_compute;
684
- std::vector<uint8_t> buf_work;
685
- std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
 
 
 
 
 
 
686
 
687
- int buf_last = 0;
688
- size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
 
689
 
690
  // decode output (2-dimensional array: [n_tokens][n_vocab])
691
  std::vector<float> logits;
@@ -705,6 +732,10 @@ struct whisper_state {
705
  whisper_coreml_context * ctx_coreml = nullptr;
706
  #endif
707
 
 
 
 
 
708
  #ifdef WHISPER_USE_OPENVINO
709
  whisper_openvino_context * ctx_openvino = nullptr;
710
  #endif
@@ -717,37 +748,6 @@ struct whisper_state {
717
 
718
  // [EXPERIMENTAL] speed-up techniques
719
  int32_t exp_n_audio_ctx = 0; // 0 - use default
720
-
721
- void use_buf(struct ggml_context * ctx, int i) {
722
- #if defined(WHISPER_USE_SCRATCH)
723
- size_t last_size = 0;
724
-
725
- if (i == -1) {
726
- last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
727
- } else {
728
- auto & buf = buf_scratch[i];
729
- last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
730
- }
731
-
732
- if (buf_last >= 0) {
733
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
734
- }
735
-
736
- buf_last = i;
737
- #else
738
- (void) i;
739
- (void) ctx;
740
- #endif
741
- }
742
-
743
- size_t get_buf_max_mem(int i) const {
744
- #if defined(WHISPER_USE_SCRATCH)
745
- return buf_max_size[i];
746
- #else
747
- (void) i;
748
- return 0;
749
- #endif
750
- }
751
  };
752
 
753
  struct whisper_context {
@@ -794,10 +794,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
794
 
795
  static bool kv_cache_init(
796
  const struct whisper_hparams & hparams,
797
- const size_t mem_bytes,
798
  struct whisper_kv_cache & cache,
799
  ggml_type wtype,
800
  int n_ctx) {
 
 
 
 
 
 
 
 
801
  cache.buf.resize(mem_bytes);
802
 
803
  struct ggml_init_params params = {
@@ -813,12 +820,6 @@ static bool kv_cache_init(
813
  return false;
814
  }
815
 
816
- const int n_text_state = hparams.n_text_state;
817
- const int n_text_layer = hparams.n_text_layer;
818
-
819
- const int n_mem = n_text_layer*n_ctx;
820
- const int n_elements = n_text_state*n_mem;
821
-
822
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
823
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
824
 
@@ -961,22 +962,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
961
 
962
  // print memory requirements
963
  {
964
- // this is the total memory required to run the inference
965
- const size_t mem_required =
966
- MEM_REQ_SCRATCH0.at(model.type) +
967
- MEM_REQ_SCRATCH1.at(model.type) +
968
- MEM_REQ_SCRATCH2.at(model.type) +
969
- MEM_REQ_SCRATCH3.at(model.type) +
970
- scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
971
- scale*MEM_REQ_KV_CROSS.at(model.type) +
972
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
973
-
974
- // this is the memory required by one decoder
975
- const size_t mem_required_decoder =
976
- scale*MEM_REQ_KV_SELF.at(model.type);
977
-
978
- log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
979
- mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
980
  }
981
 
982
  // initialize all memory buffers
@@ -1485,49 +1473,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1485
  return true;
1486
  }
1487
 
1488
- // evaluate the encoder with the given state
1489
- //
1490
- // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1491
- // part of the transformer model and returns the encoded features
1492
- //
1493
- // - wctx: the model
1494
- // - wstate: the state of the encoder
1495
- // - n_threads: number of threads to use
1496
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1497
- //
1498
- static bool whisper_encode_internal(
1499
- whisper_context & wctx,
1500
- whisper_state & wstate,
1501
- const int mel_offset,
1502
- const int n_threads){
1503
 
1504
- const int64_t t_start_us = ggml_time_us();
 
 
 
 
 
 
 
 
 
 
 
 
 
1505
 
 
 
 
 
1506
  const auto & model = wctx.model;
1507
  const auto & mel_inp = wstate.mel;
1508
  const auto & hparams = model.hparams;
1509
 
1510
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1511
- const int n_state = hparams.n_audio_state;
1512
- const int n_head = hparams.n_audio_head;
1513
- const int n_layer = hparams.n_audio_layer;
1514
 
1515
  const int n_mels = hparams.n_mels;
1516
- assert(mel_inp.n_mel == n_mels);
1517
 
1518
  struct ggml_init_params params = {
1519
- /*.mem_size =*/ wstate.buf_compute.size(),
1520
- /*.mem_buffer =*/ wstate.buf_compute.data(),
1521
- /*.no_alloc =*/ false,
1522
  };
1523
 
1524
  struct ggml_context * ctx0 = ggml_init(params);
1525
 
1526
- wstate.use_buf(ctx0, 0);
 
 
1527
 
1528
  struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
 
 
1529
  assert(mel->type == GGML_TYPE_F32);
1530
- {
 
 
1531
  float * dst = (float *) mel->data;
1532
  memset(dst, 0, ggml_nbytes(mel));
1533
 
@@ -1541,25 +1536,11 @@ static bool whisper_encode_internal(
1541
  }
1542
  }
1543
 
1544
- struct ggml_tensor * cur;
1545
-
1546
- #ifndef WHISPER_USE_COREML
1547
- const bool use_coreml = false;
1548
- #else
1549
- const bool use_coreml = wstate.ctx_coreml != nullptr;
1550
- #endif
1551
-
1552
- #ifndef WHISPER_USE_OPENVINO
1553
- const bool use_openvino = false;
1554
- #else
1555
- const bool use_openvino = wstate.ctx_openvino != nullptr;
1556
- #endif
1557
 
1558
- if (!use_coreml && !use_openvino) {
1559
  // convolution + gelu
1560
  {
1561
- wstate.use_buf(ctx0, 1);
1562
-
1563
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1564
  cur = ggml_add(ctx0,
1565
  ggml_repeat(ctx0,
@@ -1569,8 +1550,6 @@ static bool whisper_encode_internal(
1569
 
1570
  cur = ggml_gelu(ctx0, cur);
1571
 
1572
- wstate.use_buf(ctx0, 0);
1573
-
1574
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1575
  cur = ggml_add(ctx0,
1576
  ggml_repeat(ctx0,
@@ -1581,371 +1560,431 @@ static bool whisper_encode_internal(
1581
  cur = ggml_gelu(ctx0, cur);
1582
  }
1583
 
1584
- wstate.use_buf(ctx0, 3);
 
 
 
 
1585
 
1586
- // ===================================================================
1587
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
1588
- //static int iter = -1;
1589
- //const int n_iter = 1500/n_ctx;
 
 
 
1590
 
1591
- //iter = (iter + 1) % n_iter;
 
 
 
1592
 
1593
- //if (iter == 0) {
1594
- // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1595
- // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1596
- //}
1597
 
1598
- static int iter = 0;
1599
 
1600
- const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1601
- const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1602
 
1603
- struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
1604
 
1605
- cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
 
 
 
 
1606
 
1607
- // ===================================================================
 
 
 
1608
 
1609
- // original:
1610
- //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
 
 
 
1611
 
1612
- struct ggml_tensor * inpL = cur;
1613
 
1614
- for (int il = 0; il < n_layer; ++il) {
1615
- const auto & layer = model.layers_encoder[il];
1616
 
1617
- // norm
1618
- {
1619
- wstate.use_buf(ctx0, 0);
1620
 
1621
- cur = ggml_norm(ctx0, inpL, hparams.eps);
 
1622
 
1623
- // cur = ln_0_w*cur + ln_0_b
1624
- cur = ggml_add(ctx0,
1625
- ggml_mul(ctx0,
1626
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1627
- cur),
1628
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1629
- }
1630
 
1631
- // self-attention
1632
- {
1633
- wstate.use_buf(ctx0, 1);
1634
 
1635
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1636
- layer.attn_q_w,
1637
- cur);
 
1638
 
1639
- Qcur = ggml_add(ctx0,
1640
- ggml_repeat(ctx0,
1641
- layer.attn_q_b,
1642
- Qcur),
1643
- Qcur);
1644
 
1645
- //Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
 
 
1646
 
1647
- // note: no bias for Key
1648
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1649
- layer.attn_k_w,
1650
- cur);
1651
 
1652
- //Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
1653
 
1654
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1655
- layer.attn_v_w,
1656
- cur);
1657
 
1658
- Vcur = ggml_add(ctx0,
1659
- ggml_repeat(ctx0,
1660
- layer.attn_v_b,
1661
- Vcur),
1662
- Vcur);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1663
 
1664
- // ------
 
 
 
 
1665
 
1666
- wstate.use_buf(ctx0, 0);
1667
 
1668
  #ifdef WHISPER_USE_FLASH_ATTN
1669
- struct ggml_tensor * Q =
1670
- ggml_permute(ctx0,
1671
- ggml_cpy(ctx0,
1672
- Qcur,
1673
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1674
- 0, 2, 1, 3);
1675
-
1676
- struct ggml_tensor * K =
1677
- ggml_permute(ctx0,
1678
- ggml_cpy(ctx0,
1679
- Kcur,
1680
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1681
- 0, 2, 1, 3);
1682
-
1683
- struct ggml_tensor * V =
1684
- ggml_cpy(ctx0,
1685
- ggml_permute(ctx0,
1686
- ggml_reshape_3d(ctx0,
1687
- Vcur,
1688
- n_state/n_head, n_head, n_ctx),
1689
- 1, 2, 0, 3),
1690
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1691
-
1692
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
1693
  #else
1694
- struct ggml_tensor * Q =
1695
- ggml_permute(ctx0,
1696
- ggml_cpy(ctx0,
1697
- Qcur,
1698
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1699
- 0, 2, 1, 3);
1700
-
1701
- struct ggml_tensor * K =
1702
- ggml_permute(ctx0,
1703
- ggml_cpy(ctx0,
1704
- Kcur,
1705
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1706
- 0, 2, 1, 3);
1707
-
1708
- // K * Q
1709
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1710
-
1711
- struct ggml_tensor * KQ_scaled =
1712
- ggml_scale_inplace(ctx0,
1713
- KQ,
1714
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1715
- );
1716
-
1717
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled);
1718
-
1719
- struct ggml_tensor * V =
1720
- ggml_cpy(ctx0,
1721
- ggml_permute(ctx0,
1722
- ggml_reshape_3d(ctx0,
1723
- Vcur,
1724
- n_state/n_head, n_head, n_ctx),
1725
- 1, 2, 0, 3),
1726
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1727
- );
1728
-
1729
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
1730
  #endif
1731
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1732
 
1733
- wstate.use_buf(ctx0, 1);
 
 
 
1734
 
1735
- cur = ggml_cpy(ctx0,
1736
- KQV_merged,
1737
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1738
- }
 
1739
 
1740
- // projection
1741
- {
1742
- wstate.use_buf(ctx0, 0);
1743
 
1744
- cur = ggml_mul_mat(ctx0,
1745
- layer.attn_ln_1_w,
1746
- cur);
1747
 
1748
- wstate.use_buf(ctx0, 1);
1749
 
 
 
 
 
 
 
 
1750
  cur = ggml_add(ctx0,
1751
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1752
- cur);
1753
  }
1754
 
1755
- wstate.use_buf(ctx0, 2);
 
 
 
 
 
 
 
 
1756
 
1757
- // add the input
1758
- cur = ggml_add(ctx0, cur, inpL);
1759
 
1760
- struct ggml_tensor * inpFF = cur;
 
1761
 
1762
- // feed-forward network
1763
- {
1764
- // norm
1765
- {
1766
- wstate.use_buf(ctx0, 0);
1767
 
1768
- cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1769
 
1770
- wstate.use_buf(ctx0, 1);
1771
 
1772
- // cur = mlp_ln_w*cur + mlp_ln_b
1773
- cur = ggml_add(ctx0,
1774
- ggml_mul(ctx0,
1775
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1776
- cur),
1777
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1778
- }
1779
 
1780
- #ifdef WHISPER_USE_FLASH_FF
1781
- wstate.use_buf(ctx0, 0);
1782
 
1783
- cur = ggml_flash_ff(ctx0,
1784
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1785
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1786
- #else
1787
- wstate.use_buf(ctx0, 0);
1788
 
1789
- // fully connected
1790
- cur = ggml_mul_mat(ctx0,
1791
- layer.mlp_0_w,
1792
- cur);
1793
 
1794
- wstate.use_buf(ctx0, 1);
 
 
1795
 
1796
- cur = ggml_add(ctx0,
1797
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
1798
- cur);
1799
 
1800
- wstate.use_buf(ctx0, 0);
 
 
1801
 
1802
- // GELU activation
1803
- cur = ggml_gelu(ctx0, cur);
1804
 
1805
- wstate.use_buf(ctx0, 1);
 
 
1806
 
1807
- // projection
1808
- cur = ggml_mul_mat(ctx0,
1809
- layer.mlp_1_w,
1810
- cur);
1811
 
1812
- wstate.use_buf(ctx0, 0);
1813
 
1814
- cur = ggml_add(ctx0,
1815
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
1816
- cur);
1817
- #endif
1818
- }
1819
 
1820
- wstate.use_buf(ctx0, 3);
 
 
1821
 
1822
- inpL = ggml_add(ctx0, cur, inpFF);
1823
- }
 
1824
 
1825
- cur = inpL;
1826
 
1827
- // norm
1828
- {
1829
- wstate.use_buf(ctx0, 0);
1830
 
1831
- cur = ggml_norm(ctx0, cur, hparams.eps);
 
1832
 
1833
- wstate.use_buf(ctx0, 1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1834
 
1835
- // cur = ln_f_g*cur + ln_f_b
1836
- cur = ggml_add(ctx0,
1837
- ggml_mul(ctx0,
1838
- ggml_repeat(ctx0, model.e_ln_w, cur),
1839
- cur),
1840
- ggml_repeat(ctx0, model.e_ln_b, cur));
1841
- }
1842
 
1843
- wstate.use_buf(ctx0, -1);
1844
 
1845
- // run the computation
1846
- {
1847
- struct ggml_cgraph gf = {};
1848
 
1849
- ggml_build_forward_expand(&gf, cur);
1850
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
1851
 
1852
- //ggml_graph_print(&gf);
 
1853
  }
1854
  }
1855
- #ifdef WHISPER_USE_COREML
1856
- else if (use_coreml) {
1857
- wstate.use_buf(ctx0, -1);
1858
 
1859
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
 
 
1860
 
1861
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1862
- }
1863
- #endif
1864
- #ifdef WHISPER_USE_OPENVINO
1865
- else if (use_openvino) {
1866
- wstate.use_buf(ctx0, -1);
1867
 
1868
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1869
 
1870
- if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
1871
- return false;
 
 
 
 
 
 
1872
  }
1873
- }
 
1874
  #endif
 
1875
 
1876
- // cur
1877
- //{
1878
- // printf("ne0 = %d\n", cur->ne[0]);
1879
- // printf("ne1 = %d\n", cur->ne[1]);
1880
- // for (int i = 0; i < 10; ++i) {
1881
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1882
- // }
1883
- // printf("... ");
1884
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1885
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1886
- // }
1887
- // printf("\n");
1888
- //}
1889
-
1890
- // pre-compute cross-attention memory
1891
  {
1892
- struct ggml_cgraph gf = {};
1893
-
1894
- // TODO: hack to disconnect the encoded features from the previous graph
1895
- cur->op = GGML_OP_NONE;
1896
- cur->src[0] = nullptr;
1897
- cur->src[1] = nullptr;
1898
-
1899
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1900
- auto& layer = model.layers_decoder[il];
1901
 
1902
- wstate.use_buf(ctx0, 0);
1903
 
1904
- struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
1905
- layer.cross_attn_k_w,
1906
- cur);
1907
-
1908
- Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
1909
-
1910
- wstate.use_buf(ctx0, 1);
1911
-
1912
- struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
1913
- layer.cross_attn_v_w,
1914
- cur);
1915
-
1916
- Vcross = ggml_add(ctx0,
1917
- ggml_repeat(ctx0,
1918
- layer.cross_attn_v_b,
1919
- Vcross),
1920
- Vcross);
1921
 
1922
- wstate.use_buf(ctx0, -1);
1923
 
1924
- Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1925
-
1926
- struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1927
- struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1928
- ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
1929
- (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
1930
-
1931
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1932
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1933
  }
1934
-
1935
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
1936
- //ggml_graph_print(&gf);
1937
  }
1938
 
1939
- ////////////////////////////////////////////////////////////////////////////
1940
-
1941
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1942
- // ggml_used_mem(ctx0)/1024.0/1024.0,
1943
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1944
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1945
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1946
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1947
-
1948
- ggml_free(ctx0);
1949
 
1950
  wstate.t_encode_us += ggml_time_us() - t_start_us;
1951
  wstate.n_encode++;
@@ -1953,26 +1992,13 @@ static bool whisper_encode_internal(
1953
  return true;
1954
  }
1955
 
1956
- // evaluate the decoder
1957
- //
1958
- // given text prompt + audio features -> computes the logits for the next token
1959
- //
1960
- // - model: the model
1961
- // - n_threads: number of threads to use
1962
- // - tokens: text prompt
1963
- // - n_tokens: number of tokens in the prompt
1964
- // - n_past: number of past tokens to prefix the prompt with
1965
- //
1966
- static bool whisper_decode_internal(
1967
- whisper_context & wctx,
1968
- whisper_state & wstate,
1969
- whisper_decoder & decoder,
1970
- const whisper_token * tokens,
1971
- const int n_tokens,
1972
- const int n_past,
1973
- const int n_threads) {
1974
- const int64_t t_start_us = ggml_time_us();
1975
-
1976
  const auto & model = wctx.model;
1977
  const auto & hparams = model.hparams;
1978
 
@@ -1980,10 +2006,6 @@ static bool whisper_decode_internal(
1980
 
1981
  WHISPER_ASSERT(!!kv_self.ctx);
1982
 
1983
- auto & logits_out = wstate.logits;
1984
-
1985
- const int n_vocab = hparams.n_vocab;
1986
-
1987
  const int n_ctx = hparams.n_text_ctx;
1988
  const int n_state = hparams.n_text_state;
1989
  const int n_head = hparams.n_text_head;
@@ -1995,24 +2017,39 @@ static bool whisper_decode_internal(
1995
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1996
 
1997
  struct ggml_init_params params = {
1998
- /*.mem_size =*/ wstate.buf_compute.size(),
1999
- /*.mem_buffer =*/ wstate.buf_compute.data(),
2000
- /*.no_alloc =*/ false,
2001
  };
2002
 
2003
  struct ggml_context * ctx0 = ggml_init(params);
2004
 
2005
- struct ggml_cgraph gf = {};
 
 
2006
 
2007
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2008
- memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
 
 
 
2009
 
2010
  struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2011
- for (int i = 0; i < N; ++i) {
2012
- ((int32_t *) position->data)[i] = n_past + i;
 
 
 
 
2013
  }
2014
 
2015
- wstate.use_buf(ctx0, 3);
 
 
 
 
 
2016
 
2017
  // token encoding + position encoding
2018
  struct ggml_tensor * cur =
@@ -2027,16 +2064,14 @@ static bool whisper_decode_internal(
2027
 
2028
  // norm
2029
  {
2030
- wstate.use_buf(ctx0, 0);
2031
-
2032
  cur = ggml_norm(ctx0, inpL, hparams.eps);
2033
 
2034
  // cur = ln_0_w*cur + ln_0_b
2035
  cur = ggml_add(ctx0,
2036
  ggml_mul(ctx0,
2037
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
2038
- cur),
2039
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
2040
  }
2041
 
2042
  // self-attention
@@ -2046,19 +2081,17 @@ static bool whisper_decode_internal(
2046
  cur);
2047
 
2048
  Qcur = ggml_add(ctx0,
2049
- ggml_repeat(ctx0,
2050
- layer.attn_q_b,
2051
- Qcur),
2052
- Qcur);
2053
 
2054
- Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2055
 
2056
  // note: no bias for Key
2057
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
2058
  layer.attn_k_w,
2059
  cur);
2060
 
2061
- Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2062
 
2063
  // store key and value to memory
2064
  {
@@ -2067,10 +2100,8 @@ static bool whisper_decode_internal(
2067
  cur);
2068
 
2069
  Vcur = ggml_add(ctx0,
2070
- ggml_repeat(ctx0,
2071
- layer.attn_v_b,
2072
- Vcur),
2073
- Vcur);
2074
 
2075
  Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
2076
 
@@ -2079,42 +2110,32 @@ static bool whisper_decode_internal(
2079
  ( n_ctx)*ggml_element_size(kv_self.v),
2080
  (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
2081
 
2082
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
2083
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
2084
  }
2085
 
2086
  // ------
2087
 
2088
- wstate.use_buf(ctx0, 0);
2089
-
2090
  struct ggml_tensor * Q =
2091
  ggml_permute(ctx0,
2092
- ggml_cpy(ctx0,
2093
- Qcur,
2094
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
2095
  0, 2, 1, 3);
2096
 
2097
  struct ggml_tensor * K =
2098
- ggml_permute(ctx0,
2099
- ggml_reshape_3d(ctx0,
2100
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
2101
- n_state/n_head, n_head, n_past + N),
2102
- 0, 2, 1, 3);
2103
-
2104
- wstate.use_buf(ctx0, 1);
2105
 
2106
  // K * Q
2107
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2108
 
2109
- //struct ggml_tensor * KQ_scaled =
2110
- // ggml_scale_inplace(ctx0,
2111
- // KQ,
2112
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2113
- // );
2114
 
2115
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ, n_past);
2116
 
2117
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
2118
 
2119
  struct ggml_tensor * V =
2120
  ggml_view_3d(ctx0, kv_self.v,
@@ -2134,36 +2155,28 @@ static bool whisper_decode_internal(
2134
 
2135
  // projection
2136
  {
2137
- wstate.use_buf(ctx0, 0);
2138
-
2139
  cur = ggml_mul_mat(ctx0,
2140
  layer.attn_ln_1_w,
2141
  cur);
2142
 
2143
- wstate.use_buf(ctx0, 1);
2144
-
2145
  cur = ggml_add(ctx0,
2146
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
2147
- cur);
2148
  }
2149
 
2150
- wstate.use_buf(ctx0, 2);
2151
-
2152
  // add the input
2153
  struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
2154
 
2155
  // norm
2156
  {
2157
- wstate.use_buf(ctx0, 0);
2158
-
2159
  cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
2160
 
2161
  // cur = ln_0_w*cur + ln_0_b
2162
  cur = ggml_add(ctx0,
2163
  ggml_mul(ctx0,
2164
- ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
2165
- cur),
2166
- ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
2167
  }
2168
 
2169
  // cross-attention
@@ -2173,18 +2186,18 @@ static bool whisper_decode_internal(
2173
  cur);
2174
 
2175
  Qcur = ggml_add(ctx0,
2176
- ggml_repeat(ctx0,
2177
- layer.cross_attn_q_b,
2178
- Qcur),
2179
- Qcur);
2180
 
2181
- Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2182
 
2183
  // Kcross is already scaled
2184
  struct ggml_tensor * Kcross =
2185
- ggml_reshape_3d(ctx0,
2186
- ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
2187
- n_state/n_head, n_head, M);
 
 
2188
 
2189
  //struct ggml_tensor * Vcross =
2190
  // ggml_reshape_3d(ctx0,
@@ -2207,26 +2220,22 @@ static bool whisper_decode_internal(
2207
 
2208
  struct ggml_tensor * Q =
2209
  ggml_permute(ctx0,
2210
- ggml_cpy(ctx0,
2211
- Qcur,
2212
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
2213
  0, 2, 1, 3);
2214
 
2215
- struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2216
-
2217
  // K * Q
2218
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2219
 
2220
  //struct ggml_tensor * KQ_scaled =
2221
- // ggml_scale_inplace(ctx0,
2222
  // KQ,
2223
  // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2224
  // );
2225
 
2226
  // no masking for cross-attention
2227
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2228
 
2229
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ);
2230
 
2231
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2232
 
@@ -2240,21 +2249,15 @@ static bool whisper_decode_internal(
2240
 
2241
  // projection
2242
  {
2243
- wstate.use_buf(ctx0, 0);
2244
-
2245
  cur = ggml_mul_mat(ctx0,
2246
  layer.cross_attn_ln_1_w,
2247
  cur);
2248
 
2249
- wstate.use_buf(ctx0, 1);
2250
-
2251
  cur = ggml_add(ctx0,
2252
- ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2253
- cur);
2254
  }
2255
 
2256
- wstate.use_buf(ctx0, 2);
2257
-
2258
  // add the input
2259
  cur = ggml_add(ctx0, cur, inpCA);
2260
 
@@ -2264,54 +2267,38 @@ static bool whisper_decode_internal(
2264
  {
2265
  // norm
2266
  {
2267
- wstate.use_buf(ctx0, 0);
2268
-
2269
  cur = ggml_norm(ctx0, inpFF, hparams.eps);
2270
 
2271
- wstate.use_buf(ctx0, 1);
2272
-
2273
  // cur = mlp_ln_w*cur + mlp_ln_b
2274
  cur = ggml_add(ctx0,
2275
  ggml_mul(ctx0,
2276
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
2277
- cur),
2278
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2279
  }
2280
 
2281
- wstate.use_buf(ctx0, 0);
2282
-
2283
  // fully connected
2284
  cur = ggml_mul_mat(ctx0,
2285
  layer.mlp_0_w,
2286
  cur);
2287
 
2288
- wstate.use_buf(ctx0, 1);
2289
-
2290
  cur = ggml_add(ctx0,
2291
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
2292
- cur);
2293
-
2294
- wstate.use_buf(ctx0, 0);
2295
 
2296
  // GELU activation
2297
  cur = ggml_gelu(ctx0, cur);
2298
 
2299
- wstate.use_buf(ctx0, 1);
2300
-
2301
  // projection
2302
  cur = ggml_mul_mat(ctx0,
2303
  layer.mlp_1_w,
2304
  cur);
2305
 
2306
- wstate.use_buf(ctx0, 0);
2307
-
2308
  cur = ggml_add(ctx0,
2309
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
2310
- cur);
2311
  }
2312
 
2313
- wstate.use_buf(ctx0, 3);
2314
-
2315
  inpL = ggml_add(ctx0, cur, inpFF);
2316
  }
2317
 
@@ -2319,21 +2306,15 @@ static bool whisper_decode_internal(
2319
 
2320
  // norm
2321
  {
2322
- wstate.use_buf(ctx0, 0);
2323
-
2324
  cur = ggml_norm(ctx0, cur, hparams.eps);
2325
 
2326
- wstate.use_buf(ctx0, 1);
2327
-
2328
  cur = ggml_add(ctx0,
2329
  ggml_mul(ctx0,
2330
- ggml_repeat(ctx0, model.d_ln_w, cur),
2331
- cur),
2332
- ggml_repeat(ctx0, model.d_ln_b, cur));
2333
  }
2334
 
2335
- wstate.use_buf(ctx0, 0);
2336
-
2337
  // compute logits only for the last token
2338
  // comment this line to compute logits for all N tokens
2339
  // might be useful in the future
@@ -2341,23 +2322,75 @@ static bool whisper_decode_internal(
2341
 
2342
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2343
 
2344
- wstate.use_buf(ctx0, -1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2345
 
2346
- // run the computation
2347
  {
2348
- ggml_build_forward_expand(&gf, logits);
2349
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2350
  }
2351
 
2352
  // extract logits for all N tokens
2353
- //logits_out.resize(N*n_vocab);
2354
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
2355
 
2356
  // extract logits only for the last token
2357
  logits_out.resize(n_vocab);
2358
  memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2359
 
2360
- if (N > 1) {
2361
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2362
  // ggml_used_mem(ctx0)/1024.0/1024.0,
2363
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2366,14 +2399,18 @@ static bool whisper_decode_internal(
2366
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2367
  }
2368
 
2369
- ggml_free(ctx0);
2370
-
2371
- wstate.t_decode_us += ggml_time_us() - t_start_us;
2372
- wstate.n_decode++;
 
 
 
2373
 
2374
  return true;
2375
  }
2376
 
 
2377
  // 500 -> 00:05.000
2378
  // 6000 -> 01:00.000
2379
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2782,9 +2819,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2782
  fill_sin_cos_table();
2783
  whisper_state * state = new whisper_state;
2784
 
2785
- const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
2786
-
2787
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2788
  log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2789
  delete state;
2790
  return nullptr;
@@ -2795,7 +2830,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2795
  log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2796
  }
2797
 
2798
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2799
  log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2800
  delete state;
2801
  return nullptr;
@@ -2816,6 +2851,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2816
  if (!state->ctx_coreml) {
2817
  log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2818
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
 
2819
  return nullptr;
2820
  #endif
2821
  } else {
@@ -2830,15 +2866,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2830
  // TAGS: WHISPER_DECODER_INIT
2831
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2832
 
2833
- state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
2834
- state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
2835
  state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
2836
- state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
2837
 
2838
- state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
2839
- state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
2840
- state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
2841
- state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2842
 
2843
  state->rng = std::mt19937(0);
2844
 
@@ -2895,7 +3027,6 @@ int whisper_ctx_init_openvino_encoder(
2895
  }
2896
 
2897
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2898
-
2899
  log("%s: loading model from '%s'\n", __func__, path_model);
2900
 
2901
  auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3048,6 +3179,13 @@ void whisper_free_state(struct whisper_state * state)
3048
  }
3049
  #endif
3050
 
 
 
 
 
 
 
 
3051
  #ifdef WHISPER_USE_OPENVINO
3052
  if (state->ctx_openvino != nullptr) {
3053
  whisper_openvino_free(state->ctx_openvino);
@@ -3055,6 +3193,11 @@ void whisper_free_state(struct whisper_state * state)
3055
  }
3056
  #endif
3057
 
 
 
 
 
 
3058
  delete state;
3059
  }
3060
  }
@@ -3475,12 +3618,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
3475
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3476
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3477
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
 
3478
 
3479
  log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3480
  log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3481
  log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3482
  log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3483
  log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
 
3484
  }
3485
  log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3486
  }
@@ -3490,6 +3635,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3490
  ctx->state->t_sample_us = 0;
3491
  ctx->state->t_encode_us = 0;
3492
  ctx->state->t_decode_us = 0;
 
 
 
 
 
3493
  }
3494
  }
3495
 
@@ -4339,6 +4489,21 @@ int whisper_full_with_state(
4339
  decoder.probs.resize (ctx->vocab.n_vocab);
4340
  decoder.logits.resize (ctx->vocab.n_vocab);
4341
  decoder.logprobs.resize(ctx->vocab.n_vocab);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4342
  }
4343
  }
4344
 
@@ -4531,8 +4696,8 @@ int whisper_full_with_state(
4531
 
4532
  decoder.kv_self.n += prompt.size();
4533
 
4534
- memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4535
- memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4536
  memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
4537
  }
4538
 
@@ -5045,6 +5210,12 @@ int whisper_full_parallel(
5045
  ctx->state->t_sample_us += states[i]->t_sample_us;
5046
  ctx->state->t_encode_us += states[i]->t_encode_us;
5047
  ctx->state->t_decode_us += states[i]->t_decode_us;
 
 
 
 
 
 
5048
 
5049
  whisper_free_state(states[i]);
5050
  }
@@ -5241,8 +5412,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5241
  // b: N*N*sizeof(float)
5242
  // c: N*N*sizeof(float)
5243
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5244
- std::vector<uint8_t> buf (3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
5245
- std::vector<uint8_t> work(1llu*N_max*N_max*sizeof(float) + 1*ggml_tensor_overhead());
5246
 
5247
  // put a bunch of random data in the buffer
5248
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
 
3
  #include "coreml/whisper-encoder.h"
4
  #endif
5
 
6
+ #ifdef GGML_USE_METAL
7
+ # include "ggml-metal.h"
8
+ #endif
9
+
10
  #ifdef WHISPER_USE_OPENVINO
11
  #include "openvino/whisper-openvino-encoder.h"
12
  #endif
13
 
14
  #include "ggml.h"
15
+ #include "ggml-alloc.h"
16
 
17
  #include <algorithm>
18
  #include <cassert>
 
29
  #include <vector>
30
  #include <regex>
31
  #include <random>
32
+ #include <functional>
33
 
34
  #if defined(_MSC_VER)
35
  #pragma warning(disable: 4244 4267) // possible loss of data
 
121
  //#define WHISPER_USE_FLASH_FF
122
  #define WHISPER_MAX_DECODERS 16
123
 
 
 
 
124
  //
125
  // ggml helpers
126
  //
 
136
  ggml_graph_compute(graph, &plan);
137
  }
138
 
139
+ // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
140
+ // the idea is to represent the original matrix multiplication:
141
+ //
142
+ // Z = X @ Y
143
+ //
144
+ // with the sum of two matrix multiplications:
145
+ //
146
+ // Z = (X_0 @ Y_0) + (X_1 @ Y_1)
147
+ //
148
+ // here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
149
+ // and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
150
+ // general-purpose kernels
151
+ //
152
+ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
153
+ // use padding only if dimension 0 is at least 8 times larger than the padding
154
+ // else we won't get much benefit from the optimization
155
+ const int n_pad_req = 8;
156
+
157
+ if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
158
+ return ggml_mul_mat(ctx, x, y);
159
+ }
160
+
161
+ struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
162
+ struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
163
+
164
+ struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
165
+ struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
166
+
167
+ return ggml_add(ctx,
168
+ ggml_mul_mat(ctx, x_0, y_0),
169
+ ggml_mul_mat(ctx, x_1, y_1));
170
+ }
171
+
172
+ // TODO: check if other platforms can benefit from this optimization
173
+ #if defined(GGML_USE_METAL)
174
+ #define ggml_mul_mat ggml_mul_mat_pad
175
+ #endif
176
+
177
  // available whisper models
178
  enum e_model {
179
  MODEL_UNKNOWN,
 
288
 
289
  static const size_t MB = 1ull*1024*1024;
290
 
291
+ // TODO: avoid using GGUF
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
293
  { GGML_TYPE_F32,
294
  {
 
355
  },
356
  };
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  struct whisper_mel {
359
  int n_len;
360
  int n_len_org;
 
635
  std::vector<uint8_t> v;
636
  };
637
 
638
+ // ggml_allocr wrapper for whisper usage
639
+ struct whisper_allocr {
640
+ ggml_allocr * alloc = nullptr;
641
+
642
+ std::vector<uint8_t> meta;
643
+ std::vector<uint8_t> data;
644
+ };
645
+
646
+ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
647
+ return allocr.meta.size() + allocr.data.size();
648
+ }
649
+
650
+ // measure the memory usage of a graph and prepare the allocr's internal data buffer
651
+ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
652
+ const int tensor_alignment = 32;
653
+
654
+ auto & alloc = allocr.alloc;
655
+ auto & meta = allocr.meta;
656
+ auto & data = allocr.data;
657
+
658
+ meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
659
+
660
+ alloc = ggml_allocr_new_measure(tensor_alignment);
661
+
662
+ const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
663
+
664
+ ggml_allocr_free(alloc);
665
+
666
+ data.resize(alloc_size);
667
+
668
+ alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
669
+ }
670
+
671
+ static void whisper_allocr_free(struct whisper_allocr & allocr) {
672
+ if (allocr.alloc) {
673
+ ggml_allocr_free(allocr.alloc);
674
+ allocr.alloc = nullptr;
675
+ }
676
+ }
677
+
678
  struct whisper_state {
679
  int64_t t_sample_us = 0;
680
  int64_t t_encode_us = 0;
681
  int64_t t_decode_us = 0;
682
+ int64_t t_prompt_us = 0;
683
  int64_t t_mel_us = 0;
684
 
685
  int32_t n_sample = 0; // number of tokens sampled
686
  int32_t n_encode = 0; // number of encoder calls
687
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
688
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
689
  int32_t n_fail_p = 0; // number of logprob threshold failures
690
  int32_t n_fail_h = 0; // number of entropy threshold failures
691
 
 
699
  // buffer for swapping KV caches between decoders during beam-search
700
  std::vector<kv_buf> kv_swap_bufs;
701
 
702
+ // reusable buffer for `struct ggml_graph_plan.work_data`
703
+ std::vector<uint8_t> work_buffer;
704
+
705
+ // ggml-alloc:
706
+ // - stores meta info about the intermediate tensors into the `meta` buffers
707
+ // - stores the actual tensor data into the `data` buffers
708
+ whisper_allocr alloc_conv;
709
+ whisper_allocr alloc_encode;
710
+ whisper_allocr alloc_cross;
711
+ whisper_allocr alloc_decode;
712
 
713
+ // result of the encoder
714
+ struct ggml_tensor * embd_conv = nullptr;
715
+ struct ggml_tensor * embd_enc = nullptr;
716
 
717
  // decode output (2-dimensional array: [n_tokens][n_vocab])
718
  std::vector<float> logits;
 
732
  whisper_coreml_context * ctx_coreml = nullptr;
733
  #endif
734
 
735
+ #ifdef GGML_USE_METAL
736
+ ggml_metal_context * ctx_metal = nullptr;
737
+ #endif
738
+
739
  #ifdef WHISPER_USE_OPENVINO
740
  whisper_openvino_context * ctx_openvino = nullptr;
741
  #endif
 
748
 
749
  // [EXPERIMENTAL] speed-up techniques
750
  int32_t exp_n_audio_ctx = 0; // 0 - use default
 
751
  };
752
 
753
  struct whisper_context {
 
794
 
795
  static bool kv_cache_init(
796
  const struct whisper_hparams & hparams,
 
797
  struct whisper_kv_cache & cache,
798
  ggml_type wtype,
799
  int n_ctx) {
800
+ const int64_t n_text_state = hparams.n_text_state;
801
+ const int64_t n_text_layer = hparams.n_text_layer;
802
+
803
+ const int64_t n_mem = n_text_layer*n_ctx;
804
+ const int64_t n_elements = n_text_state*n_mem;
805
+
806
+ const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
807
+
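// For reference, assuming the base model hyper-parameters (n_text_state = 512, n_text_layer = 6)
// and the default n_text_ctx = 448 with F16 storage, the self-attention cache comes out to:
//
//   n_elements = 512 * (6 * 448)            = 1,376,256
//   mem_bytes  ~ 2 * 2 bytes * 1,376,256    ~ 5.25 MB (plus two tensor overheads)
//
// which is roughly the "kv self size" reported by the log further below.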
808
  cache.buf.resize(mem_bytes);
809
 
810
  struct ggml_init_params params = {
 
820
  return false;
821
  }
822
 
 
 
 
 
 
 
823
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
824
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
825
 
 
962
 
963
  // print memory requirements
964
  {
965
+ // TODO
966
+ //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
967
+ // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
968
  }
969
 
970
  // initialize all memory buffers
 
1473
  return true;
1474
  }
1475
 
1476
+ static bool whisper_encode_external(const whisper_state & wstate) {
1477
+ GGML_UNUSED(wstate);
1478
 
1479
+ #ifndef WHISPER_USE_COREML
1480
+ const bool use_coreml = false;
1481
+ #else
1482
+ const bool use_coreml = wstate.ctx_coreml != nullptr;
1483
+ #endif
1484
+
1485
+ #ifndef WHISPER_USE_OPENVINO
1486
+ const bool use_openvino = false;
1487
+ #else
1488
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
1489
+ #endif
1490
+
1491
+ return use_coreml || use_openvino;
1492
+ }
1493
 
1494
+ static struct ggml_cgraph * whisper_build_graph_conv(
1495
+ whisper_context & wctx,
1496
+ whisper_state & wstate,
1497
+ const int mel_offset) {
1498
  const auto & model = wctx.model;
1499
  const auto & mel_inp = wstate.mel;
1500
  const auto & hparams = model.hparams;
1501
 
1502
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1503
+ const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);
 
 
1504
 
1505
  const int n_mels = hparams.n_mels;
 
1506
 
1507
  struct ggml_init_params params = {
1508
+ /*.mem_size =*/ wstate.alloc_conv.meta.size(),
1509
+ /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
1510
+ /*.no_alloc =*/ true,
1511
  };
1512
 
1513
  struct ggml_context * ctx0 = ggml_init(params);
1514
 
1515
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
1516
+
1517
+ ggml_allocr * alloc = wstate.alloc_conv.alloc;
1518
 
1519
  struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1520
+ ggml_allocr_alloc(alloc, mel);
1521
+
1522
  assert(mel->type == GGML_TYPE_F32);
1523
+ if (!ggml_allocr_is_measure(alloc)) {
1524
+ assert(mel_inp.n_mel == n_mels);
1525
+
1526
  float * dst = (float *) mel->data;
1527
  memset(dst, 0, ggml_nbytes(mel));
1528
 
 
1536
  }
1537
  }
1538
 
1539
+ struct ggml_tensor * cur = nullptr;
1540
 
1541
+ if (!whisper_encode_external(wstate)) {
1542
  // convolution + gelu
1543
  {
 
 
1544
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1545
  cur = ggml_add(ctx0,
1546
  ggml_repeat(ctx0,
 
1550
 
1551
  cur = ggml_gelu(ctx0, cur);
1552
 
 
 
1553
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1554
  cur = ggml_add(ctx0,
1555
  ggml_repeat(ctx0,
 
1560
  cur = ggml_gelu(ctx0, cur);
1561
  }
1562
 
1563
+ wstate.embd_conv = cur;
1564
+ } else {
1565
+ #ifdef WHISPER_USE_COREML
1566
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1567
+ ggml_allocr_alloc(alloc, cur);
1568
 
1569
+ if (!ggml_allocr_is_measure(alloc)) {
1570
+ whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1571
+ }
1572
+ #endif
1573
+ #ifdef WHISPER_USE_OPENVINO
1574
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1575
+ ggml_allocr_alloc(alloc, cur);
1576
 
1577
+ if (!ggml_allocr_is_measure(alloc)) {
1578
+ whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
1579
+ }
1580
+ #endif
1581
 
1582
+ wstate.embd_enc = cur;
1583
+ }
 
 
1584
 
1585
+ ggml_build_forward_expand(gf, cur);
1586
 
1587
+ ggml_free(ctx0);
 
1588
 
1589
+ return gf;
1590
+ }
1591
 
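// Note on the two outputs of the builder above:
//
//   wstate.embd_conv - set when the full ggml encoder is used; consumed as the input of
//                      whisper_build_graph_encoder() below
//   wstate.embd_enc  - set when an external encoder (Core ML / OpenVINO) produces the encoded
//                      features directly, in which case the ggml encoder graph is skipped
//                      (see whisper_encode_external() above and whisper_encode_internal() below)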
1592
+ static struct ggml_cgraph * whisper_build_graph_encoder(
1593
+ whisper_context & wctx,
1594
+ whisper_state & wstate) {
1595
+ const auto & model = wctx.model;
1596
+ const auto & hparams = model.hparams;
1597
 
1598
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1599
+ const int n_state = hparams.n_audio_state;
1600
+ const int n_head = hparams.n_audio_head;
1601
+ const int n_layer = hparams.n_audio_layer;
1602
 
1603
+ struct ggml_init_params params = {
1604
+ /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1605
+ /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
1606
+ /*.no_alloc =*/ true,
1607
+ };
1608
 
1609
+ struct ggml_context * ctx0 = ggml_init(params);
1610
 
1611
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
 
1612
 
1613
+ ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
 
1614
 
1615
+ struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1616
+ ggml_allocr_alloc(alloc, KQscale);
1617
 
1618
+ if (!ggml_allocr_is_measure(alloc)) {
1619
+ ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
1620
+ }
 
 
 
 
1621
 
1622
+ struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
 
1623
 
1624
+ // ===================================================================
1625
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
1626
+ //static int iter = -1;
1627
+ //const int n_iter = 1500/n_ctx;
1628
 
1629
+ //iter = (iter + 1) % n_iter;
 
 
 
 
1630
 
1631
+ //if (iter == 0) {
1632
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1633
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1634
+ //}
1635
 
1636
+ static int iter = 0;
 
 
 
1637
 
1638
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1639
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1640
 
1641
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
 
1642
 
1643
+ cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
1644
+
1645
+ // ===================================================================
1646
+
1647
+ // original:
1648
+ //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1649
+
1650
+ struct ggml_tensor * inpL = cur;
1651
+
1652
+ for (int il = 0; il < n_layer; ++il) {
1653
+ const auto & layer = model.layers_encoder[il];
1654
+
1655
+ // norm
1656
+ {
1657
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
1658
+
1659
+ // cur = ln_0_w*cur + ln_0_b
1660
+ cur = ggml_add(ctx0,
1661
+ ggml_mul(ctx0, cur, layer.attn_ln_0_w),
1662
+ layer.attn_ln_0_b);
1663
+ }
1664
+
1665
+ // self-attention
1666
+ {
1667
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1668
+ layer.attn_q_w,
1669
+ cur);
1670
+
1671
+ Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
1672
+
1673
+ //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1674
+
1675
+ // note: no bias for Key
1676
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1677
+ layer.attn_k_w,
1678
+ cur);
1679
+
1680
+ //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1681
 
1682
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1683
+ layer.attn_v_w,
1684
+ cur);
1685
+
1686
+ Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
1687
 
1688
+ // ------
1689
 
1690
  #ifdef WHISPER_USE_FLASH_ATTN
1691
+ struct ggml_tensor * Q =
1692
+ ggml_permute(ctx0,
1693
+ ggml_cpy(ctx0,
1694
+ Qcur,
1695
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1696
+ 0, 2, 1, 3);
1697
+
1698
+ struct ggml_tensor * K =
1699
+ ggml_permute(ctx0,
1700
+ ggml_cpy(ctx0,
1701
+ Kcur,
1702
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1703
+ 0, 2, 1, 3);
1704
+
1705
+ struct ggml_tensor * V =
1706
+ ggml_cpy(ctx0,
1707
+ ggml_permute(ctx0,
1708
+ ggml_reshape_3d(ctx0,
1709
+ Vcur,
1710
+ n_state/n_head, n_head, n_ctx),
1711
+ 1, 2, 0, 3),
1712
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1713
+
1714
+ struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
1715
  #else
1716
+ struct ggml_tensor * Q =
1717
+ ggml_permute(ctx0,
1718
+ ggml_cpy(ctx0,
1719
+ Qcur,
1720
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1721
+ 0, 2, 1, 3);
1722
+
1723
+ struct ggml_tensor * K =
1724
+ ggml_permute(ctx0,
1725
+ ggml_cpy(ctx0,
1726
+ Kcur,
1727
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1728
+ 0, 2, 1, 3);
1729
+
1730
+ // K * Q
1731
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1732
+
1733
+ struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
1734
+
1735
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
1736
+
1737
+ struct ggml_tensor * V =
1738
+ ggml_cpy(ctx0,
1739
+ ggml_permute(ctx0,
1740
+ ggml_reshape_3d(ctx0,
1741
+ Vcur,
1742
+ n_state/n_head, n_head, n_ctx),
1743
+ 1, 2, 0, 3),
1744
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1745
+ );
1746
+
1747
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
 
 
 
1748
  #endif
1749
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1750
 
1751
+ cur = ggml_cpy(ctx0,
1752
+ KQV_merged,
1753
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1754
+ }
1755
 
1756
+ // projection
1757
+ {
1758
+ cur = ggml_mul_mat(ctx0,
1759
+ layer.attn_ln_1_w,
1760
+ cur);
1761
 
1762
+ cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
1763
+ }
 
1764
 
1765
+ // add the input
1766
+ cur = ggml_add(ctx0, cur, inpL);
 
1767
 
1768
+ struct ggml_tensor * inpFF = cur;
1769
 
1770
+ // feed-forward network
1771
+ {
1772
+ // norm
1773
+ {
1774
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
1775
+
1776
+ // cur = mlp_ln_w*cur + mlp_ln_b
1777
  cur = ggml_add(ctx0,
1778
+ ggml_mul(ctx0, cur, layer.mlp_ln_w),
1779
+ layer.mlp_ln_b);
1780
  }
1781
 
1782
+ #ifdef WHISPER_USE_FLASH_FF
1783
+ cur = ggml_flash_ff(ctx0,
1784
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1785
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1786
+ #else
1787
+ // fully connected
1788
+ cur = ggml_mul_mat(ctx0,
1789
+ layer.mlp_0_w,
1790
+ cur);
1791
 
1792
+ cur = ggml_add(ctx0, cur, layer.mlp_0_b);
 
1793
 
1794
+ // GELU activation
1795
+ cur = ggml_gelu(ctx0, cur);
1796
 
1797
+ // projection
1798
+ cur = ggml_mul_mat(ctx0,
1799
+ layer.mlp_1_w,
1800
+ cur);
 
1801
 
1802
+ cur = ggml_add(ctx0, cur, layer.mlp_1_b);
1803
+ #endif
1804
+ }
1805
+
1806
+ inpL = ggml_add(ctx0, cur, inpFF);
1807
+ }
1808
+
1809
+ cur = inpL;
1810
+
1811
+ // norm
1812
+ {
1813
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1814
+
1815
+ // cur = ln_f_g*cur + ln_f_b
1816
+ cur = ggml_add(ctx0,
1817
+ ggml_mul(ctx0, cur, model.e_ln_w),
1818
+ model.e_ln_b);
1819
+ }
1820
+
1821
+ ggml_build_forward_expand(gf, cur);
1822
+
1823
+ wstate.embd_enc = cur;
1824
+
1825
+ //ggml_graph_print(gf);
1826
+
1827
+ ////////////////////////////////////////////////////////////////////////////
1828
+
1829
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1830
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
1831
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1832
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1833
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1834
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1835
+
1836
+ ggml_free(ctx0);
1837
+
1838
+ return gf;
1839
+ }
1840
+
1841
+ // pre-compute cross-attention memory
1842
+ static struct ggml_cgraph * whisper_build_graph_cross(
1843
+ whisper_context & wctx,
1844
+ whisper_state & wstate) {
1845
+ const auto & model = wctx.model;
1846
+ const auto & hparams = model.hparams;
1847
+
1848
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1849
+ const int n_state = hparams.n_audio_state;
1850
+ const int n_head = hparams.n_audio_head;
1851
+
1852
+ struct ggml_init_params params = {
1853
+ /*.mem_size =*/ wstate.alloc_cross.meta.size(),
1854
+ /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
1855
+ /*.no_alloc =*/ true,
1856
+ };
1857
 
1858
+ struct ggml_context * ctx0 = ggml_init(params);
1859
 
1860
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
 
 
 
 
 
 
1861
 
1862
+ ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
1863
 
1864
+ struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
 
 
 
1865
 
1866
+ struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1867
+ ggml_allocr_alloc(alloc, Kscale);
 
 
1868
 
1869
+ if (!ggml_allocr_is_measure(alloc)) {
1870
+ ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
1871
+ }
1872
 
1873
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1874
+ auto & layer = model.layers_decoder[il];
 
1875
 
1876
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1877
+ layer.cross_attn_k_w,
1878
+ cur);
1879
 
1880
+ Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
1881
 
1882
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1883
+ layer.cross_attn_v_w,
1884
+ cur);
1885
 
1886
+ Vcross = ggml_add(ctx0,
1887
+ Vcross,
1888
+ layer.cross_attn_v_b);
 
1889
 
1890
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1891
 
1892
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
1893
+ n_state*n_ctx,
1894
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
 
 
1895
 
1896
+ struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1897
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
1898
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
1899
 
1900
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
1901
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
1902
+ }
1903
 
1904
+ //ggml_graph_print(gf);
1905
 
1906
+ ggml_free(ctx0);
 
 
1907
 
1908
+ return gf;
1909
+ }
1910
 
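// Resulting layout of the cross-attention cache filled above, per decoder layer `il`
// (a sketch derived from the two views):
//
//   kv_cross.k : contiguous 1D span of n_state*n_ctx elements starting at il*n_ctx*n_state
//                (keys are pre-scaled with Kscale, so the decoder skips that multiplication)
//   kv_cross.v : 2D region of shape [n_ctx, n_state] starting at row il*n_state
//                (values stored transposed, ready to be used as the V operand in the decoder)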
1911
+ // evaluate the encoder with the given state
1912
+ //
1913
+ // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1914
+ // part of the transformer model and returns the encoded features
1915
+ //
1916
+ // - wctx: the model
1917
+ // - wstate: the state of the encoder
1918
+ // - n_threads: number of threads to use
1919
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1920
+ //
1921
+ static bool whisper_encode_internal(
1922
+ whisper_context & wctx,
1923
+ whisper_state & wstate,
1924
+ const int mel_offset,
1925
+ const int n_threads) {
1926
+ const int64_t t_start_us = ggml_time_us();
1927
 
1928
+ // conv
1929
+ {
1930
+ auto & alloc = wstate.alloc_conv.alloc;
 
 
 
 
1931
 
1932
+ ggml_allocr_reset(alloc);
1933
 
1934
+ ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
 
1935
 
1936
+ ggml_allocr_alloc_graph(alloc, gf);
 
1937
 
1938
+ if (!whisper_encode_external(wstate)) {
1939
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1940
  }
1941
  }
 
 
 
1942
 
1943
+ // encoder
1944
+ if (!whisper_encode_external(wstate)) {
1945
+ auto & alloc = wstate.alloc_encode.alloc;
1946
 
1947
+ ggml_allocr_reset(alloc);
 
 
 
 
 
1948
 
1949
+ ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
1950
 
1951
+ ggml_allocr_alloc_graph(alloc, gf);
1952
+
1953
+ #ifdef GGML_USE_METAL
1954
+ if (wstate.ctx_metal) {
1955
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1956
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
1957
+ } else {
1958
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1959
  }
1960
+ #else
1961
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1962
  #endif
1963
+ }
1964
 
1965
+ // cross
1966
  {
1967
+ auto & alloc = wstate.alloc_cross.alloc;
 
 
 
 
 
 
 
 
1968
 
1969
+ ggml_allocr_reset(alloc);
1970
 
1971
+ ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
1972
 
1973
+ ggml_allocr_alloc_graph(alloc, gf);
1974
 
1975
+ #ifdef GGML_USE_METAL
1976
+ if (wstate.ctx_metal) {
1977
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1978
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
1979
+ } else {
1980
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
 
 
 
1981
  }
1982
+ #else
1983
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1984
+ #endif
1985
  }
1986
 
1987
+ // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
1989
  wstate.t_encode_us += ggml_time_us() - t_start_us;
1990
  wstate.n_encode++;
 
1992
  return true;
1993
  }
1994
 
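// The encode path above, condensed (Metal / CPU dispatch omitted):
//
//   1. conv    : reset alloc_conv,   build whisper_build_graph_conv(),    compute
//                (the compute step is skipped when Core ML / OpenVINO provide the features)
//   2. encoder : reset alloc_encode, build whisper_build_graph_encoder(), compute
//                (skipped entirely when an external encoder is used)
//   3. cross   : reset alloc_cross,  build whisper_build_graph_cross(),   compute
//                (always runs - it fills kv_cross for the decoder's cross-attention)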
1995
+ static struct ggml_cgraph * whisper_build_graph_decoder(
1996
+ whisper_context & wctx,
1997
+ whisper_state & wstate,
1998
+ whisper_decoder & decoder,
1999
+ const whisper_token * tokens,
2000
+ int n_tokens,
2001
+ int n_past) {
2002
  const auto & model = wctx.model;
2003
  const auto & hparams = model.hparams;
2004
 
 
2006
 
2007
  WHISPER_ASSERT(!!kv_self.ctx);
2008
 
 
 
 
 
2009
  const int n_ctx = hparams.n_text_ctx;
2010
  const int n_state = hparams.n_text_state;
2011
  const int n_head = hparams.n_text_head;
 
2017
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
2018
 
2019
  struct ggml_init_params params = {
2020
+ /*.mem_size =*/ wstate.alloc_decode.meta.size(),
2021
+ /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
2022
+ /*.no_alloc =*/ true,
2023
  };
2024
 
2025
  struct ggml_context * ctx0 = ggml_init(params);
2026
 
2027
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
2028
+
2029
+ ggml_allocr * alloc = wstate.alloc_decode.alloc;
2030
 
2031
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2032
+ ggml_allocr_alloc(alloc, embd);
2033
+
2034
+ if (!ggml_allocr_is_measure(alloc)) {
2035
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
2036
+ }
2037
 
2038
  struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
2039
+ ggml_allocr_alloc(alloc, position);
2040
+
2041
+ if (!ggml_allocr_is_measure(alloc)) {
2042
+ for (int i = 0; i < N; ++i) {
2043
+ ((int32_t *) position->data)[i] = n_past + i;
2044
+ }
2045
  }
2046
 
2047
+ struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
2048
+ ggml_allocr_alloc(alloc, KQscale);
2049
+
2050
+ if (!ggml_allocr_is_measure(alloc)) {
2051
+ ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
2052
+ }
2053
 
2054
  // token encoding + position encoding
2055
  struct ggml_tensor * cur =
 
2064
 
2065
  // norm
2066
  {
 
 
2067
  cur = ggml_norm(ctx0, inpL, hparams.eps);
2068
 
2069
  // cur = ln_0_w*cur + ln_0_b
2070
  cur = ggml_add(ctx0,
2071
  ggml_mul(ctx0,
2072
+ cur,
2073
+ layer.attn_ln_0_w),
2074
+ layer.attn_ln_0_b);
2075
  }
2076
 
2077
  // self-attention
 
2081
  cur);
2082
 
2083
  Qcur = ggml_add(ctx0,
2084
+ Qcur,
2085
+ layer.attn_q_b);
 
 
2086
 
2087
+ Qcur = ggml_scale(ctx0, Qcur, KQscale);
2088
 
2089
  // note: no bias for Key
2090
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
2091
  layer.attn_k_w,
2092
  cur);
2093
 
2094
+ Kcur = ggml_scale(ctx0, Kcur, KQscale);
2095
 
2096
  // store key and value to memory
2097
  {
 
2100
  cur);
2101
 
2102
  Vcur = ggml_add(ctx0,
2103
+ Vcur,
2104
+ layer.attn_v_b);
 
 
2105
 
2106
  Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
2107
 
 
2110
  ( n_ctx)*ggml_element_size(kv_self.v),
2111
  (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
2112
 
2113
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
2114
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
2115
  }
2116
 
2117
  // ------
2118
 
 
 
2119
  struct ggml_tensor * Q =
2120
  ggml_permute(ctx0,
2121
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
 
 
2122
  0, 2, 1, 3);
2123
 
2124
  struct ggml_tensor * K =
2125
+ ggml_view_3d(ctx0, kv_self.k,
2126
+ n_state/n_head, n_past + N, n_head,
2127
+ ggml_element_size(kv_self.k)*n_state,
2128
+ ggml_element_size(kv_self.k)*n_state/n_head,
2129
+ ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
 
2130
 
2131
  // K * Q
2132
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2133
 
2134
+ //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
 
 
 
2135
 
2136
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
2137
 
2138
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
2139
 
2140
  struct ggml_tensor * V =
2141
  ggml_view_3d(ctx0, kv_self.v,
 
2155
 
2156
  // projection
2157
  {
 
 
2158
  cur = ggml_mul_mat(ctx0,
2159
  layer.attn_ln_1_w,
2160
  cur);
2161
 
 
 
2162
  cur = ggml_add(ctx0,
2163
+ cur,
2164
+ layer.attn_ln_1_b);
2165
  }
2166
 
 
 
2167
  // add the input
2168
  struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
2169
 
2170
  // norm
2171
  {
 
 
2172
  cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
2173
 
2174
  // cur = ln_0_w*cur + ln_0_b
2175
  cur = ggml_add(ctx0,
2176
  ggml_mul(ctx0,
2177
+ cur,
2178
+ layer.cross_attn_ln_0_w),
2179
+ layer.cross_attn_ln_0_b);
2180
  }
2181
 
2182
  // cross-attention
 
2186
  cur);
2187
 
2188
  Qcur = ggml_add(ctx0,
2189
+ Qcur,
2190
+ layer.cross_attn_q_b);
 
 
2191
 
2192
+ Qcur = ggml_scale(ctx0, Qcur, KQscale);
2193
 
2194
  // Kcross is already scaled
2195
  struct ggml_tensor * Kcross =
2196
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
2197
+ n_state/n_head, M, n_head,
2198
+ ggml_element_size(wstate.kv_cross.k)*n_state,
2199
+ ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2200
+ ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2201
 
2202
  //struct ggml_tensor * Vcross =
2203
  // ggml_reshape_3d(ctx0,
 
2220
 
2221
  struct ggml_tensor * Q =
2222
  ggml_permute(ctx0,
2223
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
 
 
2224
  0, 2, 1, 3);
2225
 
 
 
2226
  // K * Q
2227
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
2228
 
2229
  //struct ggml_tensor * KQ_scaled =
2230
+ // ggml_scale(ctx0,
2231
  // KQ,
2232
  // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2233
  // );
2234
 
2235
  // no masking for cross-attention
2236
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2237
 
2238
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2239
 
2240
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2241
 
 
2249
 
2250
  // projection
2251
  {
 
 
2252
  cur = ggml_mul_mat(ctx0,
2253
  layer.cross_attn_ln_1_w,
2254
  cur);
2255
 
 
 
2256
  cur = ggml_add(ctx0,
2257
+ cur,
2258
+ layer.cross_attn_ln_1_b);
2259
  }
2260
 
 
 
2261
  // add the input
2262
  cur = ggml_add(ctx0, cur, inpCA);
2263
 
 
2267
  {
2268
  // norm
2269
  {
 
 
2270
  cur = ggml_norm(ctx0, inpFF, hparams.eps);
2271
 
 
 
2272
  // cur = mlp_ln_w*cur + mlp_ln_b
2273
  cur = ggml_add(ctx0,
2274
  ggml_mul(ctx0,
2275
+ cur,
2276
+ layer.mlp_ln_w),
2277
+ layer.mlp_ln_b);
2278
  }
2279
 
 
 
2280
  // fully connected
2281
  cur = ggml_mul_mat(ctx0,
2282
  layer.mlp_0_w,
2283
  cur);
2284
 
 
 
2285
  cur = ggml_add(ctx0,
2286
+ cur,
2287
+ layer.mlp_0_b);
 
 
2288
 
2289
  // GELU activation
2290
  cur = ggml_gelu(ctx0, cur);
2291
 
 
 
2292
  // projection
2293
  cur = ggml_mul_mat(ctx0,
2294
  layer.mlp_1_w,
2295
  cur);
2296
 
 
 
2297
  cur = ggml_add(ctx0,
2298
+ cur,
2299
+ layer.mlp_1_b);
2300
  }
2301
 
 
 
2302
  inpL = ggml_add(ctx0, cur, inpFF);
2303
  }
2304
 
 
2306
 
2307
  // norm
2308
  {
 
 
2309
  cur = ggml_norm(ctx0, cur, hparams.eps);
2310
 
 
 
2311
  cur = ggml_add(ctx0,
2312
  ggml_mul(ctx0,
2313
+ cur,
2314
+ model.d_ln_w),
2315
+ model.d_ln_b);
2316
  }
2317
 
 
 
2318
  // compute logits only for the last token
2319
  // comment this line to compute logits for all N tokens
2320
  // might be useful in the future
 
2322
 
2323
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2324
 
2325
+ ggml_build_forward_expand(gf, logits);
2326
+
2327
+ ggml_free(ctx0);
2328
+
2329
+ return gf;
2330
+ }
2331
+
2332
+ // evaluate the decoder
2333
+ //
2334
+ // given text prompt + audio features -> computes the logits for the next token
2335
+ //
2336
+ // - wctx: the model
2337
+ // - n_threads: number of threads to use
2338
+ // - tokens: text prompt
2339
+ // - n_tokens: number of tokens in the prompt
2340
+ // - n_past: number of past tokens to prefix the prompt with
2341
+ //
2342
+ static bool whisper_decode_internal(
2343
+ whisper_context & wctx,
2344
+ whisper_state & wstate,
2345
+ whisper_decoder & decoder,
2346
+ const whisper_token * tokens,
2347
+ const int n_tokens,
2348
+ const int n_past,
2349
+ const int n_threads) {
2350
+ const int64_t t_start_us = ggml_time_us();
2351
+
2352
+ const auto & model = wctx.model;
2353
+ const auto & hparams = model.hparams;
2354
+
2355
+ const int n_vocab = hparams.n_vocab;
2356
+
2357
+ auto & logits_out = wstate.logits;
2358
+
2359
+ struct ggml_tensor * logits;
2360
 
2361
+ // decoder
2362
  {
2363
+ auto & alloc = wstate.alloc_decode.alloc;
2364
+
2365
+ ggml_allocr_reset(alloc);
2366
+
2367
+ ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2368
+
2369
+ ggml_allocr_alloc_graph(alloc, gf);
2370
+
2371
+ logits = gf->nodes[gf->n_nodes - 1];
2372
+
2373
+ #ifdef GGML_USE_METAL
2374
+ if (wstate.ctx_metal) {
2375
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2376
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
2377
+ } else {
2378
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2379
+ }
2380
+ #else
2381
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2382
+ #endif
2383
  }
2384
 
2385
  // extract logits for all N tokens
2386
+ //logits_out.resize(n_tokens*n_vocab);
2387
+ //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2388
 
2389
  // extract logits only for the last token
2390
  logits_out.resize(n_vocab);
2391
  memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2392
 
2393
+ if (n_tokens > 1) {
2394
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2395
  // ggml_used_mem(ctx0)/1024.0/1024.0,
2396
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
 
2399
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2400
  }
2401
 
2402
+ if (n_tokens == 1) {
2403
+ wstate.t_decode_us += ggml_time_us() - t_start_us;
2404
+ wstate.n_decode++;
2405
+ } else {
2406
+ wstate.t_prompt_us += ggml_time_us() - t_start_us;
2407
+ wstate.n_prompt++;
2408
+ }
2409
 
2410
  return true;
2411
  }
2412
 
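// Typical call pattern (a sketch; `prompt`, `token`, `n_past` and the surrounding loop belong
// to the full transcription routine further below and are assumed here):
//
//   // prompt pass: all prompt tokens in one call (n_tokens > 1)  -> t_prompt_us / n_prompt
//   whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), n_past, n_threads);
//
//   // text generation: one token per call (n_tokens == 1)        -> t_decode_us / n_decode
//   whisper_decode_internal(*ctx, *state, state->decoders[0], &token, 1, n_past, n_threads);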
2413
+
2414
  // 500 -> 00:05.000
2415
  // 6000 -> 01:00.000
2416
  static std::string to_timestamp(int64_t t, bool comma = false) {
 
2819
  fill_sin_cos_table();
2820
  whisper_state * state = new whisper_state;
2821
 
2822
+ if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
 
 
2823
  log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2824
  delete state;
2825
  return nullptr;
 
2830
  log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2831
  }
2832
 
2833
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2834
  log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2835
  delete state;
2836
  return nullptr;
 
2851
  if (!state->ctx_coreml) {
2852
  log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2853
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2854
+ delete state;
2855
  return nullptr;
2856
  #endif
2857
  } else {
 
2866
  // TAGS: WHISPER_DECODER_INIT
2867
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2868
 
2869
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2870
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2871
  state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
2872
 
2873
+ // conv allocator
2874
+ {
2875
+ whisper_allocr_graph_init(state->alloc_conv,
2876
+ [&]() {
2877
+ return whisper_build_graph_conv(*ctx, *state, 0);
2878
+ });
2879
+
2880
+ log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
2881
+ }
2882
+
2883
+ // encoder allocator
2884
+ if (!whisper_encode_external(*state)) {
2885
+ whisper_allocr_graph_init(state->alloc_encode,
2886
+ [&]() {
2887
+ return whisper_build_graph_encoder(*ctx, *state);
2888
+ });
2889
+
2890
+ log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
2891
+ }
2892
+
2893
+ // cross allocator
2894
+ {
2895
+ whisper_allocr_graph_init(state->alloc_cross,
2896
+ [&]() {
2897
+ return whisper_build_graph_cross(*ctx, *state);
2898
+ });
2899
+
2900
+ log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
2901
+ }
2902
+
2903
+ // decoder allocator
2904
+ {
2905
+ whisper_allocr_graph_init(state->alloc_decode,
2906
+ [&]() {
2907
+ const auto & hparams = ctx->model.hparams;
2908
+
2909
+ // TODO: make sure this is the worst-case scenario
2910
+ const int n_tokens = hparams.n_text_ctx;
2911
+ const int n_past = 0;
2912
+
2913
+ return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2914
+ });
2915
+
2916
+ log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2917
+ }
2918
+
2919
+ #ifdef GGML_USE_METAL
2920
+ state->ctx_metal = ggml_metal_init(1);
2921
+ if (!state->ctx_metal) {
2922
+ log("%s: ggml_metal_init() failed\n", __func__);
2923
+ delete state;
2924
+ return nullptr;
2925
+ }
2926
+
2927
+ log("%s: Metal context initialized\n", __func__);
2928
+
2929
+ // this allocates all Metal resources and memory buffers
2930
+
2931
+ void * data_ptr = NULL;
2932
+ size_t data_size = 0;
2933
+
2934
+ // TODO: add mmap support
2935
+ //if (params.use_mmap) {
2936
+ // data_ptr = ctx->model.mapping->addr;
2937
+ // data_size = ctx->model.mapping->size;
2938
+ //} else {
2939
+ // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2940
+ // data_size = ggml_get_mem_size (ctx->model.ctx);
2941
+ //}
2942
+
2943
+ data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2944
+ data_size = ggml_get_mem_size (ctx->model.ctx);
2945
+
2946
+ const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
2947
+
2948
+ log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2949
+
2950
+ #define WHISPER_METAL_CHECK_BUF(result) \
2951
+ if (!(result)) { \
2952
+ log("%s: failed to add metal buffer\n", __func__); \
2953
+ delete state; \
2954
+ return nullptr; \
2955
+ }
2956
+
2957
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
2958
+
2959
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2960
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2961
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2962
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
2963
+
2964
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2965
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2966
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2967
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
2968
+
2969
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
2970
+
2971
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
2972
+ #undef WHISPER_METAL_CHECK_BUF
2973
+ #endif
2974
 
2975
  state->rng = std::mt19937(0);
2976
 
 
3027
  }
3028
 
3029
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
 
3030
  log("%s: loading model from '%s'\n", __func__, path_model);
3031
 
3032
  auto fin = std::ifstream(path_model, std::ios::binary);
 
3179
  }
3180
  #endif
3181
 
3182
+ #ifdef GGML_USE_METAL
3183
+ if (state->ctx_metal) {
3184
+ ggml_metal_free(state->ctx_metal);
3185
+ state->ctx_metal = nullptr;
3186
+ }
3187
+ #endif
3188
+
3189
  #ifdef WHISPER_USE_OPENVINO
3190
  if (state->ctx_openvino != nullptr) {
3191
  whisper_openvino_free(state->ctx_openvino);
 
3193
  }
3194
  #endif
3195
 
3196
+ whisper_allocr_free(state->alloc_conv);
3197
+ whisper_allocr_free(state->alloc_decode);
3198
+ whisper_allocr_free(state->alloc_cross);
3199
+ whisper_allocr_free(state->alloc_encode);
3200
+
3201
  delete state;
3202
  }
3203
  }
 
3618
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3619
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3620
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3621
+ const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3622
 
3623
  log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3624
  log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3625
  log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3626
  log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3627
  log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3628
+ log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3629
  }
3630
  log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3631
  }
 
3635
  ctx->state->t_sample_us = 0;
3636
  ctx->state->t_encode_us = 0;
3637
  ctx->state->t_decode_us = 0;
3638
+ ctx->state->t_prompt_us = 0;
3639
+ ctx->state->n_sample = 0;
3640
+ ctx->state->n_encode = 0;
3641
+ ctx->state->n_decode = 0;
3642
+ ctx->state->n_prompt = 0;
3643
  }
3644
  }
3645
 
 
4489
  decoder.probs.resize (ctx->vocab.n_vocab);
4490
  decoder.logits.resize (ctx->vocab.n_vocab);
4491
  decoder.logprobs.resize(ctx->vocab.n_vocab);
4492
+
4493
+ // TODO: not very clean - look for a better way and potentially merge this with the init of decoder 0
4494
+ #ifdef GGML_USE_METAL
4495
+ #define WHISPER_METAL_CHECK_BUF(result) \
4496
+ if (!(result)) { \
4497
+ log("%s: failed to add metal buffer\n", __func__); \
4498
+ return 0; \
4499
+ }
4500
+
4501
+ const std::string kv_name = "kv_self_" + std::to_string(j);
4502
+ auto & kv_self = decoder.kv_self;
4503
+
4504
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
4505
+ #undef WHISPER_METAL_CHECK_BUF
4506
+ #endif
4507
  }
4508
  }
4509
 
 
4696
 
4697
  decoder.kv_self.n += prompt.size();
4698
 
4699
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4700
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4701
  memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
4702
  }
4703
 
 
5210
  ctx->state->t_sample_us += states[i]->t_sample_us;
5211
  ctx->state->t_encode_us += states[i]->t_encode_us;
5212
  ctx->state->t_decode_us += states[i]->t_decode_us;
5213
+ ctx->state->t_prompt_us += states[i]->t_prompt_us;
5214
+
5215
+ ctx->state->n_sample += states[i]->n_sample;
5216
+ ctx->state->n_encode += states[i]->n_encode;
5217
+ ctx->state->n_decode += states[i]->n_decode;
5218
+ ctx->state->n_prompt += states[i]->n_prompt;
5219
 
5220
  whisper_free_state(states[i]);
5221
  }
 
5412
  // b: N*N*sizeof(float)
5413
  // c: N*N*sizeof(float)
5414
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5415
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
5416
+ std::vector<uint8_t> work;
5417
 
5418
  // put a bunch of random data in the buffer
5419
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;