nalbion commited on
Commit
66cb305
·
unverified ·
1 Parent(s): 9cc4aaa

Feature/java bindings2 (#944)

Browse files

* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams

Files changed (21) hide show
  1. .github/workflows/build.yml +47 -0
  2. bindings/java/CMakeLists.txt +0 -50
  3. bindings/java/README.md +1 -11
  4. bindings/java/build.gradle +9 -1
  5. bindings/java/src/main/cpp/whisper_java.cpp +0 -33
  6. bindings/java/src/main/cpp/whisper_java.h +0 -24
  7. bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java +44 -17
  8. bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +12 -1
  9. bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java +0 -23
  10. bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java +1 -1
  11. bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java +2 -5
  12. bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java +1 -1
  13. bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java +2 -3
  14. bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java +19 -0
  15. bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java +30 -0
  16. bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java +16 -0
  17. bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +180 -54
  18. bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java +0 -7
  19. bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +38 -11
  20. whisper.cpp +14 -0
  21. whisper.h +3 -0
.github/workflows/build.yml CHANGED
@@ -125,8 +125,10 @@ jobs:
125
  include:
126
  - arch: Win32
127
  s2arc: x86
 
128
  - arch: x64
129
  s2arc: x64
 
130
  - sdl2: ON
131
  s2ver: 2.26.0
132
 
@@ -159,6 +161,12 @@ jobs:
159
  if: matrix.sdl2 == 'ON'
160
  run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
161
 
 
 
 
 
 
 
162
  - name: Upload binaries
163
  if: matrix.sdl2 == 'ON'
164
  uses: actions/upload-artifact@v1
@@ -363,3 +371,42 @@ jobs:
363
  run: |
364
  cd examples/whisper.android
365
  ./gradlew assembleRelease --no-daemon
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  include:
126
  - arch: Win32
127
  s2arc: x86
128
+ jnaPath: win32-x86
129
  - arch: x64
130
  s2arc: x64
131
+ jnaPath: win32-x86-64
132
  - sdl2: ON
133
  s2ver: 2.26.0
134
 
 
161
  if: matrix.sdl2 == 'ON'
162
  run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
163
 
164
+ - name: Upload dll
165
+ uses: actions/upload-artifact@v3
166
+ with:
167
+ name: ${{ matrix.jnaPath }}_whisper.dll
168
+ path: build/bin/${{ matrix.build }}/whisper.dll
169
+
170
  - name: Upload binaries
171
  if: matrix.sdl2 == 'ON'
172
  uses: actions/upload-artifact@v1
 
371
  run: |
372
  cd examples/whisper.android
373
  ./gradlew assembleRelease --no-daemon
374
+
375
+ java:
376
+ needs: [ 'windows' ]
377
+ runs-on: windows-latest
378
+ steps:
379
+ - uses: actions/checkout@v1
380
+
381
+ - name: Install Java
382
+ uses: actions/setup-java@v1
383
+ with:
384
+ java-version: 17
385
+
386
+ - name: Download Windows lib
387
+ uses: actions/download-artifact@v3
388
+ with:
389
+ name: win32-x86-64_whisper.dll
390
+ path: bindings/java/build/generated/resources/main/win32-x86-64
391
+
392
+ - name: Build
393
+ run: |
394
+ models\download-ggml-model.cmd tiny.en
395
+ cd bindings/java
396
+ chmod +x ./gradlew
397
+ ./gradlew build
398
+
399
+ - name: Upload jar
400
+ uses: actions/upload-artifact@v3
401
+ with:
402
+ name: whispercpp.jar
403
+ path: bindings/java/build/libs/whispercpp-*.jar
404
+
405
+ # - name: Publish package
406
+ # if: ${{ github.ref == 'refs/heads/master' }}
407
+ # uses: gradle/gradle-build-action@v2
408
+ # with:
409
+ # arguments: publish
410
+ # env:
411
+ # MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
412
+ # MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
bindings/java/CMakeLists.txt DELETED
@@ -1,50 +0,0 @@
1
- cmake_minimum_required(VERSION 3.10)
2
-
3
- project(whisper_java VERSION 1.4.2)
4
-
5
- # Set the target name and source file/s
6
- set(TARGET_NAME whisper_java)
7
- set(SOURCES src/main/cpp/whisper_java.cpp)
8
-
9
- # include <whisper.h>
10
- include_directories(../../)
11
-
12
- # Set the output directory for the DLL/shared library based on the platform as required by JNA
13
- if(WIN32)
14
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)
15
- elseif(UNIX AND NOT APPLE)
16
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)
17
- elseif(APPLE)
18
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)
19
- endif()
20
-
21
- set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})
22
- set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})
23
-
24
- # Create the whisper_java library
25
- add_library(${TARGET_NAME} SHARED ${SOURCES})
26
-
27
- # Link against ../../build/Release/whisper.dll (or so/dynlib)
28
- target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})
29
- target_link_libraries(${TARGET_NAME} PRIVATE whisper)
30
-
31
- # Set the appropriate compiler flags for Windows, Linux, and macOS
32
- if(WIN32)
33
- target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)
34
- elseif(UNIX AND NOT APPLE)
35
- target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
36
- elseif(APPLE)
37
- target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
38
- endif()
39
-
40
- target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)
41
- # add_definitions(-DWHISPER_SHARED)
42
-
43
- # Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA
44
- foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})
45
- string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)
46
- set_target_properties(${TARGET_NAME} PROPERTIES
47
- RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
48
- LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
49
- ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})
50
- endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bindings/java/README.md CHANGED
@@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
6
  * Ubuntu on x86_64
7
  * Windows on x86_64
8
 
9
- The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`.
10
-
11
- There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
12
-
13
- The most simple usage is as follows:
14
 
15
  ```java
16
  import io.github.ggerganov.whispercpp.WhisperCpp;
@@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
48
  git clone https://github.com/ggerganov/whisper.cpp.git
49
  cd whisper.cpp/bindings/java
50
 
51
- mkdir build
52
- pushd build
53
- cmake ..
54
- cmake --build .
55
- popd
56
-
57
  ./gradlew build
58
  ```
59
 
 
6
  * Ubuntu on x86_64
7
  * Windows on x86_64
8
 
9
+ The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
 
 
 
 
10
 
11
  ```java
12
  import io.github.ggerganov.whispercpp.WhisperCpp;
 
44
  git clone https://github.com/ggerganov/whisper.cpp.git
45
  cd whisper.cpp/bindings/java
46
 
 
 
 
 
 
 
47
  ./gradlew build
48
  ```
49
 
bindings/java/build.gradle CHANGED
@@ -22,6 +22,12 @@ sourceSets {
22
  }
23
  }
24
 
 
 
 
 
 
 
25
  tasks.register('copyLibwhisperSo', Copy) {
26
  from '../../build'
27
  include 'libwhisper.so'
@@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
34
  into 'build/generated/resources/main/windows-x86-64'
35
  }
36
 
37
- tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll
 
 
38
 
39
  test {
40
  systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
 
22
  }
23
  }
24
 
25
+ tasks.register('copyLibwhisperDynlib', Copy) {
26
+ from '../../build'
27
+ include 'libwhisper.dynlib'
28
+ into 'build/generated/resources/main/darwin'
29
+ }
30
+
31
  tasks.register('copyLibwhisperSo', Copy) {
32
  from '../../build'
33
  include 'libwhisper.so'
 
40
  into 'build/generated/resources/main/windows-x86-64'
41
  }
42
 
43
+ tasks.register('copyLibs') {
44
+ dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
45
+ }
46
 
47
  test {
48
  systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
bindings/java/src/main/cpp/whisper_java.cpp DELETED
@@ -1,33 +0,0 @@
1
- #include <stdio.h>
2
- #include "whisper_java.h"
3
-
4
- struct whisper_full_params default_params;
5
- struct whisper_context * whisper_ctx = nullptr;
6
-
7
- struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
8
- default_params = whisper_full_default_params(strategy);
9
-
10
- // struct whisper_java_params result = {};
11
- // return result;
12
- return;
13
- }
14
-
15
- void whisper_java_init_from_file(const char * path_model) {
16
- whisper_ctx = whisper_init_from_file(path_model);
17
- if (0 == default_params.n_threads) {
18
- whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
19
- }
20
- }
21
-
22
- /** Delegates to whisper_full, but without having to pass `whisper_full_params` */
23
- int whisper_java_full(
24
- struct whisper_context * ctx,
25
- // struct whisper_java_params params,
26
- const float * samples,
27
- int n_samples) {
28
- return whisper_full(ctx, default_params, samples, n_samples);
29
- }
30
-
31
- void whisper_java_free() {
32
- // free(default_params);
33
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bindings/java/src/main/cpp/whisper_java.h DELETED
@@ -1,24 +0,0 @@
1
- #define WHISPER_BUILD
2
- #include <whisper.h>
3
-
4
- #ifdef __cplusplus
5
- extern "C" {
6
- #endif
7
-
8
- struct whisper_java_params {
9
- };
10
-
11
- WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
12
-
13
- WHISPER_API void whisper_java_init_from_file(const char * path_model);
14
-
15
- WHISPER_API int whisper_java_full(
16
- struct whisper_context * ctx,
17
- // struct whisper_java_params params,
18
- const float * samples,
19
- int n_samples);
20
-
21
-
22
- #ifdef __cplusplus
23
- }
24
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java CHANGED
@@ -1,7 +1,8 @@
1
  package io.github.ggerganov.whispercpp;
2
 
 
3
  import com.sun.jna.Pointer;
4
- import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
5
  import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
6
 
7
  import java.io.File;
@@ -13,8 +14,9 @@ import java.io.IOException;
13
  */
14
  public class WhisperCpp implements AutoCloseable {
15
  private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
16
- private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
17
  private Pointer ctx = null;
 
 
18
 
19
  public File modelDir() {
20
  String modelDirPath = System.getenv("XDG_CACHE_HOME");
@@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
27
 
28
  /**
29
  * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
30
- * @return a Pointer to the WhisperContext
31
  */
32
- void initContext(String modelPath) throws FileNotFoundException {
33
  if (ctx != null) {
34
  lib.whisper_free(ctx);
35
  }
@@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
42
  modelPath = new File(modelDir(), modelPath).getAbsolutePath();
43
  }
44
 
45
- javaLib.whisper_java_init_from_file(modelPath);
46
  ctx = lib.whisper_init_from_file(modelPath);
47
 
48
  if (ctx == null) {
@@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
51
  }
52
 
53
  /**
54
- * Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.
55
- * `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.
 
 
 
 
56
  */
57
- public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {
58
- javaLib.whisper_java_default_params(strategy.ordinal());
59
- // return lib.whisper_full_default_params(strategy.value)
60
- }
61
 
62
- // whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params
63
- // fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {
64
- // return lib.whisper_full_default_params(strategy.value)
65
- // }
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  @Override
68
  public void close() {
69
  freeContext();
 
70
  System.out.println("Whisper closed");
71
  }
72
 
@@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
76
  }
77
  }
78
 
 
 
 
 
 
 
 
 
 
 
 
79
  /**
80
  * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
81
  * Not thread safe for same context
82
  * Uses the specified decoding strategy to obtain the text.
83
  */
84
- public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {
85
  if (ctx == null) {
86
  throw new IllegalStateException("Model not initialised");
87
  }
88
 
89
- if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {
90
  throw new IOException("Failed to process audio");
91
  }
92
 
 
1
  package io.github.ggerganov.whispercpp;
2
 
3
+ import com.sun.jna.Native;
4
  import com.sun.jna.Pointer;
5
+ import io.github.ggerganov.whispercpp.params.WhisperFullParams;
6
  import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
7
 
8
  import java.io.File;
 
14
  */
15
  public class WhisperCpp implements AutoCloseable {
16
  private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
 
17
  private Pointer ctx = null;
18
+ private Pointer greedyPointer = null;
19
+ private Pointer beamPointer = null;
20
 
21
  public File modelDir() {
22
  String modelDirPath = System.getenv("XDG_CACHE_HOME");
 
29
 
30
  /**
31
  * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
 
32
  */
33
+ public void initContext(String modelPath) throws FileNotFoundException {
34
  if (ctx != null) {
35
  lib.whisper_free(ctx);
36
  }
 
43
  modelPath = new File(modelDir(), modelPath).getAbsolutePath();
44
  }
45
 
 
46
  ctx = lib.whisper_init_from_file(modelPath);
47
 
48
  if (ctx == null) {
 
51
  }
52
 
53
  /**
54
+ * Provides default params which can be used with `whisper_full()` etc.
55
+ * Because this function allocates memory for the params, the caller must call either:
56
+ * - call `whisper_free_params()`
57
+ * - `Native.free(Pointer.nativeValue(pointer));`
58
+ *
59
+ * @param strategy - GREEDY
60
  */
61
+ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
62
+ Pointer pointer;
 
 
63
 
64
+ // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
65
+ if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
66
+ if (greedyPointer == null) {
67
+ greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
68
+ }
69
+ pointer = greedyPointer;
70
+ } else {
71
+ if (beamPointer == null) {
72
+ beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
73
+ }
74
+ pointer = beamPointer;
75
+ }
76
+
77
+ WhisperFullParams params = new WhisperFullParams(pointer);
78
+ params.read();
79
+ return params;
80
+ }
81
 
82
  @Override
83
  public void close() {
84
  freeContext();
85
+ freeParams();
86
  System.out.println("Whisper closed");
87
  }
88
 
 
92
  }
93
  }
94
 
95
+ private void freeParams() {
96
+ if (greedyPointer != null) {
97
+ Native.free(Pointer.nativeValue(greedyPointer));
98
+ greedyPointer = null;
99
+ }
100
+ if (beamPointer != null) {
101
+ Native.free(Pointer.nativeValue(beamPointer));
102
+ beamPointer = null;
103
+ }
104
+ }
105
+
106
  /**
107
  * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
108
  * Not thread safe for same context
109
  * Uses the specified decoding strategy to obtain the text.
110
  */
111
+ public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
112
  if (ctx == null) {
113
  throw new IllegalStateException("Model not initialised");
114
  }
115
 
116
+ if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
117
  throw new IOException("Failed to process audio");
118
  }
119
 
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java CHANGED
@@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
231
  void whisper_print_timings(Pointer ctx);
232
  void whisper_reset_timings(Pointer ctx);
233
 
 
 
 
 
234
  /**
 
 
 
 
 
235
  * @param strategy - WhisperSamplingStrategy.value
236
  */
237
- WhisperFullParams whisper_full_default_params(int strategy);
 
 
238
 
239
  /**
240
  * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
 
231
  void whisper_print_timings(Pointer ctx);
232
  void whisper_reset_timings(Pointer ctx);
233
 
234
+ // Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
235
+ // when `whisper_full_default_params()` tries to return a struct.
236
+ // WhisperFullParams whisper_full_default_params(int strategy);
237
+
238
  /**
239
+ * Provides default params which can be used with `whisper_full()` etc.
240
+ * Because this function allocates memory for the params, the caller must call either:
241
+ * - call `whisper_free_params()`
242
+ * - `Native.free(Pointer.nativeValue(pointer));`
243
+ *
244
  * @param strategy - WhisperSamplingStrategy.value
245
  */
246
+ Pointer whisper_full_default_params_by_ref(int strategy);
247
+
248
+ void whisper_free_params(Pointer params);
249
 
250
  /**
251
  * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java DELETED
@@ -1,23 +0,0 @@
1
- package io.github.ggerganov.whispercpp;
2
-
3
- import com.sun.jna.Library;
4
- import com.sun.jna.Native;
5
- import com.sun.jna.Pointer;
6
- import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
7
-
8
- interface WhisperJavaJnaLibrary extends Library {
9
- WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
10
-
11
- void whisper_java_default_params(int strategy);
12
-
13
- void whisper_java_free();
14
-
15
- void whisper_java_init_from_file(String modelPath);
16
-
17
- /**
18
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
19
- * Not thread safe for same context
20
- * Uses the specified decoding strategy to obtain the text.
21
- */
22
- int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java CHANGED
@@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
20
  * @param user_data User data.
21
  * @return True if the computation should proceed, false otherwise.
22
  */
23
- boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);
24
  }
 
20
  * @param user_data User data.
21
  * @return True if the computation should proceed, false otherwise.
22
  */
23
+ boolean callback(Pointer ctx, Pointer state, Pointer user_data);
24
  }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java CHANGED
@@ -1,12 +1,9 @@
1
  package io.github.ggerganov.whispercpp.callbacks;
2
 
 
3
  import com.sun.jna.Pointer;
4
- import io.github.ggerganov.whispercpp.WhisperContext;
5
- import io.github.ggerganov.whispercpp.model.WhisperState;
6
  import io.github.ggerganov.whispercpp.model.WhisperTokenData;
7
 
8
- import javax.security.auth.callback.Callback;
9
-
10
  /**
11
  * Callback to filter logits.
12
  * Can be used to modify the logits before sampling.
@@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
24
  * @param logits The array of logits.
25
  * @param user_data User data.
26
  */
27
- void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
28
  }
 
1
  package io.github.ggerganov.whispercpp.callbacks;
2
 
3
+ import com.sun.jna.Callback;
4
  import com.sun.jna.Pointer;
 
 
5
  import io.github.ggerganov.whispercpp.model.WhisperTokenData;
6
 
 
 
7
  /**
8
  * Callback to filter logits.
9
  * Can be used to modify the logits before sampling.
 
21
  * @param logits The array of logits.
22
  * @param user_data User data.
23
  */
24
+ void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
25
  }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java CHANGED
@@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
20
  * @param n_new The number of newly generated text segments.
21
  * @param user_data User data.
22
  */
23
- void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);
24
  }
 
20
  * @param n_new The number of newly generated text segments.
21
  * @param user_data User data.
22
  */
23
+ void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
24
  }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java CHANGED
@@ -1,11 +1,10 @@
1
  package io.github.ggerganov.whispercpp.callbacks;
2
 
 
3
  import com.sun.jna.Pointer;
4
  import io.github.ggerganov.whispercpp.WhisperContext;
5
  import io.github.ggerganov.whispercpp.model.WhisperState;
6
 
7
- import javax.security.auth.callback.Callback;
8
-
9
  /**
10
  * Callback for progress updates.
11
  */
@@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
19
  * @param progress The progress value.
20
  * @param user_data User data.
21
  */
22
- void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);
23
  }
 
1
  package io.github.ggerganov.whispercpp.callbacks;
2
 
3
+ import com.sun.jna.Callback;
4
  import com.sun.jna.Pointer;
5
  import io.github.ggerganov.whispercpp.WhisperContext;
6
  import io.github.ggerganov.whispercpp.model.WhisperState;
7
 
 
 
8
  /**
9
  * Callback for progress updates.
10
  */
 
18
  * @param progress The progress value.
19
  * @param user_data User data.
20
  */
21
+ void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
22
  }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package io.github.ggerganov.whispercpp.params;
2
+
3
+ import com.sun.jna.Structure;
4
+
5
+ import java.util.Arrays;
6
+ import java.util.List;
7
+
8
+ public class BeamSearchParams extends Structure {
9
+ /** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
10
+ public int beam_size;
11
+
12
+ /** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
13
+ public float patience;
14
+
15
+ @Override
16
+ protected List<String> getFieldOrder() {
17
+ return Arrays.asList("beam_size", "patience");
18
+ }
19
+ }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package io.github.ggerganov.whispercpp.params;
2
+
3
+ import com.sun.jna.IntegerType;
4
+
5
+ import java.util.function.BooleanSupplier;
6
+
7
+ public class CBool extends IntegerType implements BooleanSupplier {
8
+ public static final int SIZE = 1;
9
+ public static final CBool FALSE = new CBool(0);
10
+ public static final CBool TRUE = new CBool(1);
11
+
12
+
13
+ public CBool() {
14
+ this(0);
15
+ }
16
+
17
+ public CBool(long value) {
18
+ super(SIZE, value, true);
19
+ }
20
+
21
+ @Override
22
+ public boolean getAsBoolean() {
23
+ return intValue() == 1;
24
+ }
25
+
26
+ @Override
27
+ public String toString() {
28
+ return intValue() == 1 ? "true" : "false";
29
+ }
30
+ }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package io.github.ggerganov.whispercpp.params;
2
+
3
+ import com.sun.jna.Structure;
4
+
5
+ import java.util.Collections;
6
+ import java.util.List;
7
+
8
+ public class GreedyParams extends Structure {
9
+ /** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
10
+ public int best_of;
11
+
12
+ @Override
13
+ protected List<String> getFieldOrder() {
14
+ return Collections.singletonList("best_of");
15
+ }
16
+ }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java CHANGED
@@ -1,13 +1,14 @@
1
  package io.github.ggerganov.whispercpp.params;
2
 
3
- import com.sun.jna.Callback;
4
- import com.sun.jna.Pointer;
5
- import com.sun.jna.Structure;
6
  import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
7
  import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
8
  import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
9
  import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
10
 
 
 
 
11
  /**
12
  * Parameters for the whisper_full() function.
13
  * If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
@@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
15
  */
16
  public class WhisperFullParams extends Structure {
17
 
 
 
 
 
 
 
18
  /** Sampling strategy for whisper_full() function. */
19
  public int strategy;
20
 
21
- /** Number of threads. */
22
  public int n_threads;
23
 
24
- /** Maximum tokens to use from past text as a prompt for the decoder. */
25
  public int n_max_text_ctx;
26
 
27
- /** Start offset in milliseconds. */
28
  public int offset_ms;
29
 
30
- /** Audio duration to process in milliseconds. */
31
  public int duration_ms;
32
 
33
- /** Translate flag. */
34
- public boolean translate;
 
 
 
 
 
35
 
36
- /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */
37
- public boolean no_context;
 
 
38
 
39
- /** Flag to force single segment output (useful for streaming). */
40
- public boolean single_segment;
41
 
42
- /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). */
43
- public boolean print_special;
 
 
44
 
45
- /** Flag to print progress information. */
46
- public boolean print_progress;
47
 
48
- /** Flag to print results from within whisper.cpp (avoid it, use callback instead). */
49
- public boolean print_realtime;
 
 
50
 
51
- /** Flag to print timestamps for each text segment when printing realtime. */
52
- public boolean print_timestamps;
53
 
54
- /** [EXPERIMENTAL] Flag to enable token-level timestamps. */
55
- public boolean token_timestamps;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  public float thold_pt;
59
 
60
  /** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
61
  public float thold_ptsum;
62
 
63
- /** Maximum segment length in characters. */
64
  public int max_len;
65
 
66
- /** Flag to split on word rather than on token (when used with max_len). */
67
- public boolean split_on_word;
 
 
 
 
 
68
 
69
- /** Maximum tokens per segment (0 = no limit). */
70
  public int max_tokens;
71
 
72
- /** Flag to speed up the audio by 2x using Phase Vocoder. */
73
- public boolean speed_up;
 
 
 
 
 
74
 
75
  /** Overwrite the audio context size (0 = use default). */
76
  public int audio_ctx;
@@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
79
  * These are prepended to any existing text context from a previous call. */
80
  public String initial_prompt;
81
 
82
- /** Prompt tokens. */
83
  public Pointer prompt_tokens;
84
 
 
 
 
 
 
 
85
  /** Number of prompt tokens. */
86
  public int prompt_n_tokens;
87
 
@@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
90
  public String language;
91
 
92
  /** Flag to indicate whether to detect language automatically. */
93
- public boolean detect_language;
 
 
 
 
 
94
 
95
- /** Common decoding parameters. */
96
 
97
  /** Flag to suppress blank tokens. */
98
- public boolean suppress_blank;
 
 
 
 
 
 
 
99
 
100
  /** Flag to suppress non-speech tokens. */
101
- public boolean suppress_non_speech_tokens;
 
 
102
 
103
  /** Initial decoding temperature. */
104
  public float temperature;
@@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
109
  /** Length penalty. */
110
  public float length_penalty;
111
 
112
- /** Fallback parameters. */
113
 
114
  /** Temperature increment. */
115
  public float temperature_inc;
@@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
123
  /** No speech threshold. */
124
  public float no_speech_thold;
125
 
126
- class GreedyParams extends Structure {
127
- /** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
128
- public int best_of;
129
- }
130
-
131
  /** Greedy decoding parameters. */
132
  public GreedyParams greedy;
133
 
134
- class BeamSearchParams extends Structure {
135
- /** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
136
- int beam_size;
137
-
138
- /** ref: https://arxiv.org/pdf/2204.05424.pdf */
139
- float patience;
140
- }
141
-
142
  /**
143
  * Beam search decoding parameters.
144
  */
145
  public BeamSearchParams beam_search;
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  /**
148
  * Callback for every newly generated text segment.
 
149
  */
150
- public WhisperNewSegmentCallback new_segment_callback;
151
 
152
  /**
153
  * User data for the new_segment_callback.
@@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
156
 
157
  /**
158
  * Callback on each progress update.
 
159
  */
160
- public WhisperProgressCallback progress_callback;
161
 
162
  /**
163
  * User data for the progress_callback.
@@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
166
 
167
  /**
168
  * Callback each time before the encoder starts.
 
169
  */
170
- public WhisperEncoderBeginCallback encoder_begin_callback;
171
 
172
  /**
173
  * User data for the encoder_begin_callback.
@@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
176
 
177
  /**
178
  * Callback by each decoder to filter obtained logits.
 
179
  */
180
- public WhisperLogitsFilterCallback logits_filter_callback;
181
 
182
  /**
183
  * User data for the logits_filter_callback.
184
  */
185
  public Pointer logits_filter_callback_user_data;
186
- }
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  package io.github.ggerganov.whispercpp.params;
2
 
3
+ import com.sun.jna.*;
 
 
4
  import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
5
  import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
6
  import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
7
  import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
8
 
9
+ import java.util.Arrays;
10
+ import java.util.List;
11
+
12
  /**
13
  * Parameters for the whisper_full() function.
14
  * If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
 
16
  */
17
  public class WhisperFullParams extends Structure {
18
 
19
+ public WhisperFullParams(Pointer p) {
20
+ super(p);
21
+ // super(p, ALIGN_MSVC);
22
+ // super(p, ALIGN_GNUC);
23
+ }
24
+
25
  /** Sampling strategy for whisper_full() function. */
26
  public int strategy;
27
 
28
+ /** Number of threads. (default = 4) */
29
  public int n_threads;
30
 
31
+ /** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
32
  public int n_max_text_ctx;
33
 
34
+ /** Start offset in milliseconds. (default = 0) */
35
  public int offset_ms;
36
 
37
+ /** Audio duration to process in milliseconds. (default = 0) */
38
  public int duration_ms;
39
 
40
+ /** Translate flag. (default = false) */
41
+ public CBool translate;
42
+
43
+ /** The compliment of translateMode() */
44
+ public void transcribeMode() {
45
+ translate = CBool.FALSE;
46
+ }
47
 
48
+ /** The compliment of transcribeMode() */
49
+ public void translateMode() {
50
+ translate = CBool.TRUE;
51
+ }
52
 
53
+ /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
54
+ public CBool no_context;
55
 
56
+ /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
57
+ public void enableContext(boolean enable) {
58
+ no_context = enable ? CBool.FALSE : CBool.TRUE;
59
+ }
60
 
61
+ /** Flag to force single segment output (useful for streaming). (default = false) */
62
+ public CBool single_segment;
63
 
64
+ /** Flag to force single segment output (useful for streaming). (default = false) */
65
+ public void singleSegment(boolean single) {
66
+ single_segment = single ? CBool.TRUE : CBool.FALSE;
67
+ }
68
 
69
+ /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
70
+ public CBool print_special;
71
 
72
+ /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
73
+ public void printSpecial(boolean enable) {
74
+ print_special = enable ? CBool.TRUE : CBool.FALSE;
75
+ }
76
+
77
+ /** Flag to print progress information. (default = true) */
78
+ public CBool print_progress;
79
+
80
+ /** Flag to print progress information. (default = true) */
81
+ public void printProgress(boolean enable) {
82
+ print_progress = enable ? CBool.TRUE : CBool.FALSE;
83
+ }
84
+
85
+ /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
86
+ public CBool print_realtime;
87
+
88
+ /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
89
+ public void printRealtime(boolean enable) {
90
+ print_realtime = enable ? CBool.TRUE : CBool.FALSE;
91
+ }
92
 
93
+ /** Flag to print timestamps for each text segment when printing realtime. (default = true) */
94
+ public CBool print_timestamps;
95
+
96
+ /** Flag to print timestamps for each text segment when printing realtime. (default = true) */
97
+ public void printTimestamps(boolean enable) {
98
+ print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
99
+ }
100
+
101
+ /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
102
+ public CBool token_timestamps;
103
+
104
+ /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
105
+ public void tokenTimestamps(boolean enable) {
106
+ token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
107
+ }
108
+
109
+ /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
110
  public float thold_pt;
111
 
112
  /** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
113
  public float thold_ptsum;
114
 
115
+ /** Maximum segment length in characters. (default = 0) */
116
  public int max_len;
117
 
118
+ /** Flag to split on word rather than on token (when used with max_len). (default = false) */
119
+ public CBool split_on_word;
120
+
121
+ /** Flag to split on word rather than on token (when used with max_len). (default = false) */
122
+ public void splitOnWord(boolean enable) {
123
+ split_on_word = enable ? CBool.TRUE : CBool.FALSE;
124
+ }
125
 
126
+ /** Maximum tokens per segment (0, default = no limit) */
127
  public int max_tokens;
128
 
129
+ /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
130
+ public CBool speed_up;
131
+
132
+ /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
133
+ public void speedUp(boolean enable) {
134
+ speed_up = enable ? CBool.TRUE : CBool.FALSE;
135
+ }
136
 
137
  /** Overwrite the audio context size (0 = use default). */
138
  public int audio_ctx;
 
141
  * These are prepended to any existing text context from a previous call. */
142
  public String initial_prompt;
143
 
144
+ /** Prompt tokens. (int*) */
145
  public Pointer prompt_tokens;
146
 
147
+ public void setPromptTokens(int[] tokens) {
148
+ Memory mem = new Memory(tokens.length * 4L);
149
+ mem.write(0, tokens, 0, tokens.length);
150
+ prompt_tokens = mem;
151
+ }
152
+
153
  /** Number of prompt tokens. */
154
  public int prompt_n_tokens;
155
 
 
158
  public String language;
159
 
160
  /** Flag to indicate whether to detect language automatically. */
161
+ public CBool detect_language;
162
+
163
+ /** Flag to indicate whether to detect language automatically. */
164
+ public void detectLanguage(boolean enable) {
165
+ detect_language = enable ? CBool.TRUE : CBool.FALSE;
166
+ }
167
 
168
+ // Common decoding parameters.
169
 
170
  /** Flag to suppress blank tokens. */
171
+ public CBool suppress_blank;
172
+
173
+ public void suppressBlanks(boolean enable) {
174
+ suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
175
+ }
176
+
177
+ /** Flag to suppress non-speech tokens. */
178
+ public CBool suppress_non_speech_tokens;
179
 
180
  /** Flag to suppress non-speech tokens. */
181
+ public void suppressNonSpeechTokens(boolean enable) {
182
+ suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
183
+ }
184
 
185
  /** Initial decoding temperature. */
186
  public float temperature;
 
191
  /** Length penalty. */
192
  public float length_penalty;
193
 
194
+ // Fallback parameters.
195
 
196
  /** Temperature increment. */
197
  public float temperature_inc;
 
205
  /** No speech threshold. */
206
  public float no_speech_thold;
207
 
 
 
 
 
 
208
  /** Greedy decoding parameters. */
209
  public GreedyParams greedy;
210
 
 
 
 
 
 
 
 
 
211
  /**
212
  * Beam search decoding parameters.
213
  */
214
  public BeamSearchParams beam_search;
215
 
216
+ public void setBestOf(int bestOf) {
217
+ if (greedy == null) {
218
+ greedy = new GreedyParams();
219
+ }
220
+ greedy.best_of = bestOf;
221
+ }
222
+
223
+ public void setBeamSize(int beamSize) {
224
+ if (beam_search == null) {
225
+ beam_search = new BeamSearchParams();
226
+ }
227
+ beam_search.beam_size = beamSize;
228
+ }
229
+
230
+ public void setBeamSizeAndPatience(int beamSize, float patience) {
231
+ if (beam_search == null) {
232
+ beam_search = new BeamSearchParams();
233
+ }
234
+ beam_search.beam_size = beamSize;
235
+ beam_search.patience = patience;
236
+ }
237
+
238
  /**
239
  * Callback for every newly generated text segment.
240
+ * WhisperNewSegmentCallback
241
  */
242
+ public Pointer new_segment_callback;
243
 
244
  /**
245
  * User data for the new_segment_callback.
 
248
 
249
  /**
250
  * Callback on each progress update.
251
+ * WhisperProgressCallback
252
  */
253
+ public Pointer progress_callback;
254
 
255
  /**
256
  * User data for the progress_callback.
 
259
 
260
  /**
261
  * Callback each time before the encoder starts.
262
+ * WhisperEncoderBeginCallback
263
  */
264
+ public Pointer encoder_begin_callback;
265
 
266
  /**
267
  * User data for the encoder_begin_callback.
 
270
 
271
  /**
272
  * Callback by each decoder to filter obtained logits.
273
+ * WhisperLogitsFilterCallback
274
  */
275
+ public Pointer logits_filter_callback;
276
 
277
  /**
278
  * User data for the logits_filter_callback.
279
  */
280
  public Pointer logits_filter_callback_user_data;
 
281
 
282
+
283
+ public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
284
+ new_segment_callback = CallbackReference.getFunctionPointer(callback);
285
+ }
286
+
287
+ public void setProgressCallback(WhisperProgressCallback callback) {
288
+ progress_callback = CallbackReference.getFunctionPointer(callback);
289
+ }
290
+
291
+ public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
292
+ encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
293
+ }
294
+
295
+ public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
296
+ logits_filter_callback = CallbackReference.getFunctionPointer(callback);
297
+ }
298
+
299
+ @Override
300
+ protected List<String> getFieldOrder() {
301
+ return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
302
+ "no_context", "single_segment",
303
+ "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
304
+ "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
305
+ "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
306
+ "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
307
+ "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
308
+ "new_segment_callback", "new_segment_callback_user_data",
309
+ "progress_callback", "progress_callback_user_data",
310
+ "encoder_begin_callback", "encoder_begin_callback_user_data",
311
+ "logits_filter_callback", "logits_filter_callback_user_data");
312
+ }
313
+ }
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java DELETED
@@ -1,7 +0,0 @@
1
- package io.github.ggerganov.whispercpp.params;
2
-
3
- import com.sun.jna.Structure;
4
-
5
- public class WhisperJavaParams extends Structure {
6
-
7
- }
 
 
 
 
 
 
 
 
bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java CHANGED
@@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
2
 
3
  import static org.junit.jupiter.api.Assertions.*;
4
 
5
- import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
 
6
  import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
7
  import org.junit.jupiter.api.BeforeAll;
8
  import org.junit.jupiter.api.Test;
@@ -19,11 +20,11 @@ class WhisperCppTest {
19
  static void init() throws FileNotFoundException {
20
  // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
21
  // or you can provide the absolute path to the model file.
22
- String modelName = "base.en";
23
  try {
24
  whisper.initContext(modelName);
25
- whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
26
- // whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
27
  modelInitialised = true;
28
  } catch (FileNotFoundException ex) {
29
  System.out.println("Model " + modelName + " not found");
@@ -31,11 +32,30 @@ class WhisperCppTest {
31
  }
32
 
33
  @Test
34
- void testGetDefaultJavaParams() {
35
  // When
36
- whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
37
 
38
- // Then if it doesn't throw we've connected to whisper.cpp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  }
40
 
41
  @Test
@@ -52,6 +72,13 @@ class WhisperCppTest {
52
  byte[] b = new byte[audioInputStream.available()];
53
  float[] floats = new float[b.length / 2];
54
 
 
 
 
 
 
 
 
55
  try {
56
  audioInputStream.read(b);
57
 
@@ -61,13 +88,13 @@ class WhisperCppTest {
61
  }
62
 
63
  // When
64
- String result = whisper.fullTranscribe(/*params,*/ floats);
65
 
66
  // Then
67
- System.out.println(result);
68
- assertEquals("And so my fellow Americans, ask not what your country can do for you, " +
69
  "ask what you can do for your country.",
70
- result);
71
  } finally {
72
  audioInputStream.close();
73
  }
 
2
 
3
  import static org.junit.jupiter.api.Assertions.*;
4
 
5
+ import io.github.ggerganov.whispercpp.params.CBool;
6
+ import io.github.ggerganov.whispercpp.params.WhisperFullParams;
7
  import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
8
  import org.junit.jupiter.api.BeforeAll;
9
  import org.junit.jupiter.api.Test;
 
20
  static void init() throws FileNotFoundException {
21
  // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
22
  // or you can provide the absolute path to the model file.
23
+ String modelName = "../../models/ggml-tiny.en.bin";
24
  try {
25
  whisper.initContext(modelName);
26
+ // whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
27
+ // whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
28
  modelInitialised = true;
29
  } catch (FileNotFoundException ex) {
30
  System.out.println("Model " + modelName + " not found");
 
32
  }
33
 
34
  @Test
35
+ void testGetDefaultFullParams_BeamSearch() {
36
  // When
37
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
38
 
39
+ // Then
40
+ assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
41
+ assertNotEquals(0, params.n_threads);
42
+ assertEquals(16384, params.n_max_text_ctx);
43
+ assertFalse(params.translate);
44
+ assertEquals(0.01f, params.thold_pt);
45
+ assertEquals(2, params.beam_search.beam_size);
46
+ assertEquals(-1.0f, params.beam_search.patience);
47
+ }
48
+
49
+ @Test
50
+ void testGetDefaultFullParams_Greedy() {
51
+ // When
52
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
53
+
54
+ // Then
55
+ assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
56
+ assertNotEquals(0, params.n_threads);
57
+ assertEquals(16384, params.n_max_text_ctx);
58
+ assertEquals(2, params.greedy.best_of);
59
  }
60
 
61
  @Test
 
72
  byte[] b = new byte[audioInputStream.available()];
73
  float[] floats = new float[b.length / 2];
74
 
75
+ // WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
76
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
77
+ params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
78
+ params.print_progress = CBool.FALSE;
79
+ // params.initial_prompt = "and so my fellow Americans um, like";
80
+
81
+
82
  try {
83
  audioInputStream.read(b);
84
 
 
88
  }
89
 
90
  // When
91
+ String result = whisper.fullTranscribe(params, floats);
92
 
93
  // Then
94
+ System.err.println(result);
95
+ assertEquals("And so my fellow Americans ask not what your country can do for you " +
96
  "ask what you can do for your country.",
97
+ result.replace(",", ""));
98
  } finally {
99
  audioInputStream.close();
100
  }
whisper.cpp CHANGED
@@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
2852
  }
2853
  }
2854
 
 
 
 
 
 
 
2855
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2856
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
2857
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
@@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
3285
 
3286
  ////////////////////////////////////////////////////////////////////////////
3287
 
 
 
 
 
 
 
 
 
3288
  struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
3289
  struct whisper_full_params result = {
3290
  /*.strategy =*/ strategy,
 
2852
  }
2853
  }
2854
 
2855
+ void whisper_free_params(struct whisper_full_params * params) {
2856
+ if (params) {
2857
+ delete params;
2858
+ }
2859
+ }
2860
+
2861
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2862
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
2863
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
 
3291
 
3292
  ////////////////////////////////////////////////////////////////////////////
3293
 
3294
+ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
3295
+ struct whisper_full_params params = whisper_full_default_params(strategy);
3296
+
3297
+ struct whisper_full_params* result = new whisper_full_params();
3298
+ *result = params;
3299
+ return result;
3300
+ }
3301
+
3302
  struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
3303
  struct whisper_full_params result = {
3304
  /*.strategy =*/ strategy,
whisper.h CHANGED
@@ -113,6 +113,7 @@ extern "C" {
113
  // Frees all allocated memory
114
  WHISPER_API void whisper_free (struct whisper_context * ctx);
115
  WHISPER_API void whisper_free_state(struct whisper_state * state);
 
116
 
117
  // Convert RAW PCM audio to log mel spectrogram.
118
  // The resulting spectrogram is stored inside the default state of the provided whisper context.
@@ -409,6 +410,8 @@ extern "C" {
409
  void * logits_filter_callback_user_data;
410
  };
411
 
 
 
412
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
413
 
414
  // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
 
113
  // Frees all allocated memory
114
  WHISPER_API void whisper_free (struct whisper_context * ctx);
115
  WHISPER_API void whisper_free_state(struct whisper_state * state);
116
+ WHISPER_API void whisper_free_params(struct whisper_full_params * params);
117
 
118
  // Convert RAW PCM audio to log mel spectrogram.
119
  // The resulting spectrogram is stored inside the default state of the provided whisper context.
 
410
  void * logits_filter_callback_user_data;
411
  };
412
 
413
+ // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
414
+ WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
415
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
416
 
417
  // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text