Spaces:
Running
Running
Feature/java bindings2 (#944)
Browse files* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams
- .github/workflows/build.yml +47 -0
- bindings/java/CMakeLists.txt +0 -50
- bindings/java/README.md +1 -11
- bindings/java/build.gradle +9 -1
- bindings/java/src/main/cpp/whisper_java.cpp +0 -33
- bindings/java/src/main/cpp/whisper_java.h +0 -24
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java +44 -17
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +12 -1
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java +0 -23
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java +1 -1
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java +2 -5
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java +1 -1
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java +2 -3
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java +19 -0
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java +30 -0
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java +16 -0
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +180 -54
- bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java +0 -7
- bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +38 -11
- whisper.cpp +14 -0
- whisper.h +3 -0
.github/workflows/build.yml
CHANGED
|
@@ -125,8 +125,10 @@ jobs:
|
|
| 125 |
include:
|
| 126 |
- arch: Win32
|
| 127 |
s2arc: x86
|
|
|
|
| 128 |
- arch: x64
|
| 129 |
s2arc: x64
|
|
|
|
| 130 |
- sdl2: ON
|
| 131 |
s2ver: 2.26.0
|
| 132 |
|
|
@@ -159,6 +161,12 @@ jobs:
|
|
| 159 |
if: matrix.sdl2 == 'ON'
|
| 160 |
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
- name: Upload binaries
|
| 163 |
if: matrix.sdl2 == 'ON'
|
| 164 |
uses: actions/upload-artifact@v1
|
|
@@ -363,3 +371,42 @@ jobs:
|
|
| 363 |
run: |
|
| 364 |
cd examples/whisper.android
|
| 365 |
./gradlew assembleRelease --no-daemon
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
include:
|
| 126 |
- arch: Win32
|
| 127 |
s2arc: x86
|
| 128 |
+
jnaPath: win32-x86
|
| 129 |
- arch: x64
|
| 130 |
s2arc: x64
|
| 131 |
+
jnaPath: win32-x86-64
|
| 132 |
- sdl2: ON
|
| 133 |
s2ver: 2.26.0
|
| 134 |
|
|
|
|
| 161 |
if: matrix.sdl2 == 'ON'
|
| 162 |
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
| 163 |
|
| 164 |
+
- name: Upload dll
|
| 165 |
+
uses: actions/upload-artifact@v3
|
| 166 |
+
with:
|
| 167 |
+
name: ${{ matrix.jnaPath }}_whisper.dll
|
| 168 |
+
path: build/bin/${{ matrix.build }}/whisper.dll
|
| 169 |
+
|
| 170 |
- name: Upload binaries
|
| 171 |
if: matrix.sdl2 == 'ON'
|
| 172 |
uses: actions/upload-artifact@v1
|
|
|
|
| 371 |
run: |
|
| 372 |
cd examples/whisper.android
|
| 373 |
./gradlew assembleRelease --no-daemon
|
| 374 |
+
|
| 375 |
+
java:
|
| 376 |
+
needs: [ 'windows' ]
|
| 377 |
+
runs-on: windows-latest
|
| 378 |
+
steps:
|
| 379 |
+
- uses: actions/checkout@v1
|
| 380 |
+
|
| 381 |
+
- name: Install Java
|
| 382 |
+
uses: actions/setup-java@v1
|
| 383 |
+
with:
|
| 384 |
+
java-version: 17
|
| 385 |
+
|
| 386 |
+
- name: Download Windows lib
|
| 387 |
+
uses: actions/download-artifact@v3
|
| 388 |
+
with:
|
| 389 |
+
name: win32-x86-64_whisper.dll
|
| 390 |
+
path: bindings/java/build/generated/resources/main/win32-x86-64
|
| 391 |
+
|
| 392 |
+
- name: Build
|
| 393 |
+
run: |
|
| 394 |
+
models\download-ggml-model.cmd tiny.en
|
| 395 |
+
cd bindings/java
|
| 396 |
+
chmod +x ./gradlew
|
| 397 |
+
./gradlew build
|
| 398 |
+
|
| 399 |
+
- name: Upload jar
|
| 400 |
+
uses: actions/upload-artifact@v3
|
| 401 |
+
with:
|
| 402 |
+
name: whispercpp.jar
|
| 403 |
+
path: bindings/java/build/libs/whispercpp-*.jar
|
| 404 |
+
|
| 405 |
+
# - name: Publish package
|
| 406 |
+
# if: ${{ github.ref == 'refs/heads/master' }}
|
| 407 |
+
# uses: gradle/gradle-build-action@v2
|
| 408 |
+
# with:
|
| 409 |
+
# arguments: publish
|
| 410 |
+
# env:
|
| 411 |
+
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
|
| 412 |
+
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
|
bindings/java/CMakeLists.txt
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
cmake_minimum_required(VERSION 3.10)
|
| 2 |
-
|
| 3 |
-
project(whisper_java VERSION 1.4.2)
|
| 4 |
-
|
| 5 |
-
# Set the target name and source file/s
|
| 6 |
-
set(TARGET_NAME whisper_java)
|
| 7 |
-
set(SOURCES src/main/cpp/whisper_java.cpp)
|
| 8 |
-
|
| 9 |
-
# include <whisper.h>
|
| 10 |
-
include_directories(../../)
|
| 11 |
-
|
| 12 |
-
# Set the output directory for the DLL/shared library based on the platform as required by JNA
|
| 13 |
-
if(WIN32)
|
| 14 |
-
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)
|
| 15 |
-
elseif(UNIX AND NOT APPLE)
|
| 16 |
-
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)
|
| 17 |
-
elseif(APPLE)
|
| 18 |
-
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)
|
| 19 |
-
endif()
|
| 20 |
-
|
| 21 |
-
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
| 22 |
-
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
| 23 |
-
|
| 24 |
-
# Create the whisper_java library
|
| 25 |
-
add_library(${TARGET_NAME} SHARED ${SOURCES})
|
| 26 |
-
|
| 27 |
-
# Link against ../../build/Release/whisper.dll (or so/dynlib)
|
| 28 |
-
target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})
|
| 29 |
-
target_link_libraries(${TARGET_NAME} PRIVATE whisper)
|
| 30 |
-
|
| 31 |
-
# Set the appropriate compiler flags for Windows, Linux, and macOS
|
| 32 |
-
if(WIN32)
|
| 33 |
-
target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)
|
| 34 |
-
elseif(UNIX AND NOT APPLE)
|
| 35 |
-
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
| 36 |
-
elseif(APPLE)
|
| 37 |
-
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
| 38 |
-
endif()
|
| 39 |
-
|
| 40 |
-
target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)
|
| 41 |
-
# add_definitions(-DWHISPER_SHARED)
|
| 42 |
-
|
| 43 |
-
# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA
|
| 44 |
-
foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})
|
| 45 |
-
string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)
|
| 46 |
-
set_target_properties(${TARGET_NAME} PROPERTIES
|
| 47 |
-
RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
| 48 |
-
LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
| 49 |
-
ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})
|
| 50 |
-
endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bindings/java/README.md
CHANGED
|
@@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
|
|
| 6 |
* Ubuntu on x86_64
|
| 7 |
* Windows on x86_64
|
| 8 |
|
| 9 |
-
The "low level" bindings are in `WhisperCppJnaLibrary
|
| 10 |
-
|
| 11 |
-
There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
|
| 12 |
-
|
| 13 |
-
The most simple usage is as follows:
|
| 14 |
|
| 15 |
```java
|
| 16 |
import io.github.ggerganov.whispercpp.WhisperCpp;
|
|
@@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
|
|
| 48 |
git clone https://github.com/ggerganov/whisper.cpp.git
|
| 49 |
cd whisper.cpp/bindings/java
|
| 50 |
|
| 51 |
-
mkdir build
|
| 52 |
-
pushd build
|
| 53 |
-
cmake ..
|
| 54 |
-
cmake --build .
|
| 55 |
-
popd
|
| 56 |
-
|
| 57 |
./gradlew build
|
| 58 |
```
|
| 59 |
|
|
|
|
| 6 |
* Ubuntu on x86_64
|
| 7 |
* Windows on x86_64
|
| 8 |
|
| 9 |
+
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
```java
|
| 12 |
import io.github.ggerganov.whispercpp.WhisperCpp;
|
|
|
|
| 44 |
git clone https://github.com/ggerganov/whisper.cpp.git
|
| 45 |
cd whisper.cpp/bindings/java
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
./gradlew build
|
| 48 |
```
|
| 49 |
|
bindings/java/build.gradle
CHANGED
|
@@ -22,6 +22,12 @@ sourceSets {
|
|
| 22 |
}
|
| 23 |
}
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
tasks.register('copyLibwhisperSo', Copy) {
|
| 26 |
from '../../build'
|
| 27 |
include 'libwhisper.so'
|
|
@@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
|
|
| 34 |
into 'build/generated/resources/main/windows-x86-64'
|
| 35 |
}
|
| 36 |
|
| 37 |
-
tasks.
|
|
|
|
|
|
|
| 38 |
|
| 39 |
test {
|
| 40 |
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
|
|
|
|
| 22 |
}
|
| 23 |
}
|
| 24 |
|
| 25 |
+
tasks.register('copyLibwhisperDynlib', Copy) {
|
| 26 |
+
from '../../build'
|
| 27 |
+
include 'libwhisper.dynlib'
|
| 28 |
+
into 'build/generated/resources/main/darwin'
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
tasks.register('copyLibwhisperSo', Copy) {
|
| 32 |
from '../../build'
|
| 33 |
include 'libwhisper.so'
|
|
|
|
| 40 |
into 'build/generated/resources/main/windows-x86-64'
|
| 41 |
}
|
| 42 |
|
| 43 |
+
tasks.register('copyLibs') {
|
| 44 |
+
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
|
| 45 |
+
}
|
| 46 |
|
| 47 |
test {
|
| 48 |
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
|
bindings/java/src/main/cpp/whisper_java.cpp
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
#include <stdio.h>
|
| 2 |
-
#include "whisper_java.h"
|
| 3 |
-
|
| 4 |
-
struct whisper_full_params default_params;
|
| 5 |
-
struct whisper_context * whisper_ctx = nullptr;
|
| 6 |
-
|
| 7 |
-
struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
|
| 8 |
-
default_params = whisper_full_default_params(strategy);
|
| 9 |
-
|
| 10 |
-
// struct whisper_java_params result = {};
|
| 11 |
-
// return result;
|
| 12 |
-
return;
|
| 13 |
-
}
|
| 14 |
-
|
| 15 |
-
void whisper_java_init_from_file(const char * path_model) {
|
| 16 |
-
whisper_ctx = whisper_init_from_file(path_model);
|
| 17 |
-
if (0 == default_params.n_threads) {
|
| 18 |
-
whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
|
| 19 |
-
}
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
/** Delegates to whisper_full, but without having to pass `whisper_full_params` */
|
| 23 |
-
int whisper_java_full(
|
| 24 |
-
struct whisper_context * ctx,
|
| 25 |
-
// struct whisper_java_params params,
|
| 26 |
-
const float * samples,
|
| 27 |
-
int n_samples) {
|
| 28 |
-
return whisper_full(ctx, default_params, samples, n_samples);
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
void whisper_java_free() {
|
| 32 |
-
// free(default_params);
|
| 33 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bindings/java/src/main/cpp/whisper_java.h
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
#define WHISPER_BUILD
|
| 2 |
-
#include <whisper.h>
|
| 3 |
-
|
| 4 |
-
#ifdef __cplusplus
|
| 5 |
-
extern "C" {
|
| 6 |
-
#endif
|
| 7 |
-
|
| 8 |
-
struct whisper_java_params {
|
| 9 |
-
};
|
| 10 |
-
|
| 11 |
-
WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
|
| 12 |
-
|
| 13 |
-
WHISPER_API void whisper_java_init_from_file(const char * path_model);
|
| 14 |
-
|
| 15 |
-
WHISPER_API int whisper_java_full(
|
| 16 |
-
struct whisper_context * ctx,
|
| 17 |
-
// struct whisper_java_params params,
|
| 18 |
-
const float * samples,
|
| 19 |
-
int n_samples);
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
#ifdef __cplusplus
|
| 23 |
-
}
|
| 24 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
package io.github.ggerganov.whispercpp;
|
| 2 |
|
|
|
|
| 3 |
import com.sun.jna.Pointer;
|
| 4 |
-
import io.github.ggerganov.whispercpp.params.
|
| 5 |
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
| 6 |
|
| 7 |
import java.io.File;
|
|
@@ -13,8 +14,9 @@ import java.io.IOException;
|
|
| 13 |
*/
|
| 14 |
public class WhisperCpp implements AutoCloseable {
|
| 15 |
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
| 16 |
-
private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
|
| 17 |
private Pointer ctx = null;
|
|
|
|
|
|
|
| 18 |
|
| 19 |
public File modelDir() {
|
| 20 |
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
|
@@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
|
|
| 27 |
|
| 28 |
/**
|
| 29 |
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
| 30 |
-
* @return a Pointer to the WhisperContext
|
| 31 |
*/
|
| 32 |
-
void initContext(String modelPath) throws FileNotFoundException {
|
| 33 |
if (ctx != null) {
|
| 34 |
lib.whisper_free(ctx);
|
| 35 |
}
|
|
@@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
|
|
| 42 |
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
| 43 |
}
|
| 44 |
|
| 45 |
-
javaLib.whisper_java_init_from_file(modelPath);
|
| 46 |
ctx = lib.whisper_init_from_file(modelPath);
|
| 47 |
|
| 48 |
if (ctx == null) {
|
|
@@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
|
|
| 51 |
}
|
| 52 |
|
| 53 |
/**
|
| 54 |
-
*
|
| 55 |
-
*
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
*/
|
| 57 |
-
public
|
| 58 |
-
|
| 59 |
-
// return lib.whisper_full_default_params(strategy.value)
|
| 60 |
-
}
|
| 61 |
|
| 62 |
-
//
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
@Override
|
| 68 |
public void close() {
|
| 69 |
freeContext();
|
|
|
|
| 70 |
System.out.println("Whisper closed");
|
| 71 |
}
|
| 72 |
|
|
@@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
|
|
| 76 |
}
|
| 77 |
}
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
/**
|
| 80 |
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
| 81 |
* Not thread safe for same context
|
| 82 |
* Uses the specified decoding strategy to obtain the text.
|
| 83 |
*/
|
| 84 |
-
public String fullTranscribe(
|
| 85 |
if (ctx == null) {
|
| 86 |
throw new IllegalStateException("Model not initialised");
|
| 87 |
}
|
| 88 |
|
| 89 |
-
if (
|
| 90 |
throw new IOException("Failed to process audio");
|
| 91 |
}
|
| 92 |
|
|
|
|
| 1 |
package io.github.ggerganov.whispercpp;
|
| 2 |
|
| 3 |
+
import com.sun.jna.Native;
|
| 4 |
import com.sun.jna.Pointer;
|
| 5 |
+
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
| 6 |
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
| 7 |
|
| 8 |
import java.io.File;
|
|
|
|
| 14 |
*/
|
| 15 |
public class WhisperCpp implements AutoCloseable {
|
| 16 |
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
|
|
|
| 17 |
private Pointer ctx = null;
|
| 18 |
+
private Pointer greedyPointer = null;
|
| 19 |
+
private Pointer beamPointer = null;
|
| 20 |
|
| 21 |
public File modelDir() {
|
| 22 |
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
|
|
|
| 29 |
|
| 30 |
/**
|
| 31 |
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
|
|
|
| 32 |
*/
|
| 33 |
+
public void initContext(String modelPath) throws FileNotFoundException {
|
| 34 |
if (ctx != null) {
|
| 35 |
lib.whisper_free(ctx);
|
| 36 |
}
|
|
|
|
| 43 |
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
| 44 |
}
|
| 45 |
|
|
|
|
| 46 |
ctx = lib.whisper_init_from_file(modelPath);
|
| 47 |
|
| 48 |
if (ctx == null) {
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
/**
|
| 54 |
+
* Provides default params which can be used with `whisper_full()` etc.
|
| 55 |
+
* Because this function allocates memory for the params, the caller must call either:
|
| 56 |
+
* - call `whisper_free_params()`
|
| 57 |
+
* - `Native.free(Pointer.nativeValue(pointer));`
|
| 58 |
+
*
|
| 59 |
+
* @param strategy - GREEDY
|
| 60 |
*/
|
| 61 |
+
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
|
| 62 |
+
Pointer pointer;
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
|
| 65 |
+
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
|
| 66 |
+
if (greedyPointer == null) {
|
| 67 |
+
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
| 68 |
+
}
|
| 69 |
+
pointer = greedyPointer;
|
| 70 |
+
} else {
|
| 71 |
+
if (beamPointer == null) {
|
| 72 |
+
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
| 73 |
+
}
|
| 74 |
+
pointer = beamPointer;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
WhisperFullParams params = new WhisperFullParams(pointer);
|
| 78 |
+
params.read();
|
| 79 |
+
return params;
|
| 80 |
+
}
|
| 81 |
|
| 82 |
@Override
|
| 83 |
public void close() {
|
| 84 |
freeContext();
|
| 85 |
+
freeParams();
|
| 86 |
System.out.println("Whisper closed");
|
| 87 |
}
|
| 88 |
|
|
|
|
| 92 |
}
|
| 93 |
}
|
| 94 |
|
| 95 |
+
private void freeParams() {
|
| 96 |
+
if (greedyPointer != null) {
|
| 97 |
+
Native.free(Pointer.nativeValue(greedyPointer));
|
| 98 |
+
greedyPointer = null;
|
| 99 |
+
}
|
| 100 |
+
if (beamPointer != null) {
|
| 101 |
+
Native.free(Pointer.nativeValue(beamPointer));
|
| 102 |
+
beamPointer = null;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
/**
|
| 107 |
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
| 108 |
* Not thread safe for same context
|
| 109 |
* Uses the specified decoding strategy to obtain the text.
|
| 110 |
*/
|
| 111 |
+
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
| 112 |
if (ctx == null) {
|
| 113 |
throw new IllegalStateException("Model not initialised");
|
| 114 |
}
|
| 115 |
|
| 116 |
+
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
| 117 |
throw new IOException("Failed to process audio");
|
| 118 |
}
|
| 119 |
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
CHANGED
|
@@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
|
|
| 231 |
void whisper_print_timings(Pointer ctx);
|
| 232 |
void whisper_reset_timings(Pointer ctx);
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
/**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
* @param strategy - WhisperSamplingStrategy.value
|
| 236 |
*/
|
| 237 |
-
|
|
|
|
|
|
|
| 238 |
|
| 239 |
/**
|
| 240 |
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
|
|
|
| 231 |
void whisper_print_timings(Pointer ctx);
|
| 232 |
void whisper_reset_timings(Pointer ctx);
|
| 233 |
|
| 234 |
+
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
|
| 235 |
+
// when `whisper_full_default_params()` tries to return a struct.
|
| 236 |
+
// WhisperFullParams whisper_full_default_params(int strategy);
|
| 237 |
+
|
| 238 |
/**
|
| 239 |
+
* Provides default params which can be used with `whisper_full()` etc.
|
| 240 |
+
* Because this function allocates memory for the params, the caller must call either:
|
| 241 |
+
* - call `whisper_free_params()`
|
| 242 |
+
* - `Native.free(Pointer.nativeValue(pointer));`
|
| 243 |
+
*
|
| 244 |
* @param strategy - WhisperSamplingStrategy.value
|
| 245 |
*/
|
| 246 |
+
Pointer whisper_full_default_params_by_ref(int strategy);
|
| 247 |
+
|
| 248 |
+
void whisper_free_params(Pointer params);
|
| 249 |
|
| 250 |
/**
|
| 251 |
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
package io.github.ggerganov.whispercpp;
|
| 2 |
-
|
| 3 |
-
import com.sun.jna.Library;
|
| 4 |
-
import com.sun.jna.Native;
|
| 5 |
-
import com.sun.jna.Pointer;
|
| 6 |
-
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
| 7 |
-
|
| 8 |
-
interface WhisperJavaJnaLibrary extends Library {
|
| 9 |
-
WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
|
| 10 |
-
|
| 11 |
-
void whisper_java_default_params(int strategy);
|
| 12 |
-
|
| 13 |
-
void whisper_java_free();
|
| 14 |
-
|
| 15 |
-
void whisper_java_init_from_file(String modelPath);
|
| 16 |
-
|
| 17 |
-
/**
|
| 18 |
-
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
| 19 |
-
* Not thread safe for same context
|
| 20 |
-
* Uses the specified decoding strategy to obtain the text.
|
| 21 |
-
*/
|
| 22 |
-
int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
|
| 23 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java
CHANGED
|
@@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
|
|
| 20 |
* @param user_data User data.
|
| 21 |
* @return True if the computation should proceed, false otherwise.
|
| 22 |
*/
|
| 23 |
-
boolean callback(
|
| 24 |
}
|
|
|
|
| 20 |
* @param user_data User data.
|
| 21 |
* @return True if the computation should proceed, false otherwise.
|
| 22 |
*/
|
| 23 |
+
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
|
| 24 |
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
package io.github.ggerganov.whispercpp.callbacks;
|
| 2 |
|
|
|
|
| 3 |
import com.sun.jna.Pointer;
|
| 4 |
-
import io.github.ggerganov.whispercpp.WhisperContext;
|
| 5 |
-
import io.github.ggerganov.whispercpp.model.WhisperState;
|
| 6 |
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
| 7 |
|
| 8 |
-
import javax.security.auth.callback.Callback;
|
| 9 |
-
|
| 10 |
/**
|
| 11 |
* Callback to filter logits.
|
| 12 |
* Can be used to modify the logits before sampling.
|
|
@@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
|
|
| 24 |
* @param logits The array of logits.
|
| 25 |
* @param user_data User data.
|
| 26 |
*/
|
| 27 |
-
void callback(
|
| 28 |
}
|
|
|
|
| 1 |
package io.github.ggerganov.whispercpp.callbacks;
|
| 2 |
|
| 3 |
+
import com.sun.jna.Callback;
|
| 4 |
import com.sun.jna.Pointer;
|
|
|
|
|
|
|
| 5 |
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
| 6 |
|
|
|
|
|
|
|
| 7 |
/**
|
| 8 |
* Callback to filter logits.
|
| 9 |
* Can be used to modify the logits before sampling.
|
|
|
|
| 21 |
* @param logits The array of logits.
|
| 22 |
* @param user_data User data.
|
| 23 |
*/
|
| 24 |
+
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
| 25 |
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java
CHANGED
|
@@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
|
|
| 20 |
* @param n_new The number of newly generated text segments.
|
| 21 |
* @param user_data User data.
|
| 22 |
*/
|
| 23 |
-
void callback(
|
| 24 |
}
|
|
|
|
| 20 |
* @param n_new The number of newly generated text segments.
|
| 21 |
* @param user_data User data.
|
| 22 |
*/
|
| 23 |
+
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
|
| 24 |
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
package io.github.ggerganov.whispercpp.callbacks;
|
| 2 |
|
|
|
|
| 3 |
import com.sun.jna.Pointer;
|
| 4 |
import io.github.ggerganov.whispercpp.WhisperContext;
|
| 5 |
import io.github.ggerganov.whispercpp.model.WhisperState;
|
| 6 |
|
| 7 |
-
import javax.security.auth.callback.Callback;
|
| 8 |
-
|
| 9 |
/**
|
| 10 |
* Callback for progress updates.
|
| 11 |
*/
|
|
@@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
|
|
| 19 |
* @param progress The progress value.
|
| 20 |
* @param user_data User data.
|
| 21 |
*/
|
| 22 |
-
void callback(
|
| 23 |
}
|
|
|
|
| 1 |
package io.github.ggerganov.whispercpp.callbacks;
|
| 2 |
|
| 3 |
+
import com.sun.jna.Callback;
|
| 4 |
import com.sun.jna.Pointer;
|
| 5 |
import io.github.ggerganov.whispercpp.WhisperContext;
|
| 6 |
import io.github.ggerganov.whispercpp.model.WhisperState;
|
| 7 |
|
|
|
|
|
|
|
| 8 |
/**
|
| 9 |
* Callback for progress updates.
|
| 10 |
*/
|
|
|
|
| 18 |
* @param progress The progress value.
|
| 19 |
* @param user_data User data.
|
| 20 |
*/
|
| 21 |
+
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
|
| 22 |
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package io.github.ggerganov.whispercpp.params;
|
| 2 |
+
|
| 3 |
+
import com.sun.jna.Structure;
|
| 4 |
+
|
| 5 |
+
import java.util.Arrays;
|
| 6 |
+
import java.util.List;
|
| 7 |
+
|
| 8 |
+
public class BeamSearchParams extends Structure {
|
| 9 |
+
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
|
| 10 |
+
public int beam_size;
|
| 11 |
+
|
| 12 |
+
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
|
| 13 |
+
public float patience;
|
| 14 |
+
|
| 15 |
+
@Override
|
| 16 |
+
protected List<String> getFieldOrder() {
|
| 17 |
+
return Arrays.asList("beam_size", "patience");
|
| 18 |
+
}
|
| 19 |
+
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package io.github.ggerganov.whispercpp.params;
|
| 2 |
+
|
| 3 |
+
import com.sun.jna.IntegerType;
|
| 4 |
+
|
| 5 |
+
import java.util.function.BooleanSupplier;
|
| 6 |
+
|
| 7 |
+
public class CBool extends IntegerType implements BooleanSupplier {
|
| 8 |
+
public static final int SIZE = 1;
|
| 9 |
+
public static final CBool FALSE = new CBool(0);
|
| 10 |
+
public static final CBool TRUE = new CBool(1);
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
public CBool() {
|
| 14 |
+
this(0);
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
public CBool(long value) {
|
| 18 |
+
super(SIZE, value, true);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@Override
|
| 22 |
+
public boolean getAsBoolean() {
|
| 23 |
+
return intValue() == 1;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
@Override
|
| 27 |
+
public String toString() {
|
| 28 |
+
return intValue() == 1 ? "true" : "false";
|
| 29 |
+
}
|
| 30 |
+
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package io.github.ggerganov.whispercpp.params;
|
| 2 |
+
|
| 3 |
+
import com.sun.jna.Structure;
|
| 4 |
+
|
| 5 |
+
import java.util.Collections;
|
| 6 |
+
import java.util.List;
|
| 7 |
+
|
| 8 |
+
public class GreedyParams extends Structure {
|
| 9 |
+
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
|
| 10 |
+
public int best_of;
|
| 11 |
+
|
| 12 |
+
@Override
|
| 13 |
+
protected List<String> getFieldOrder() {
|
| 14 |
+
return Collections.singletonList("best_of");
|
| 15 |
+
}
|
| 16 |
+
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
package io.github.ggerganov.whispercpp.params;
|
| 2 |
|
| 3 |
-
import com.sun.jna
|
| 4 |
-
import com.sun.jna.Pointer;
|
| 5 |
-
import com.sun.jna.Structure;
|
| 6 |
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
| 7 |
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
| 8 |
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
| 9 |
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
/**
|
| 12 |
* Parameters for the whisper_full() function.
|
| 13 |
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
|
@@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
|
| 15 |
*/
|
| 16 |
public class WhisperFullParams extends Structure {
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
/** Sampling strategy for whisper_full() function. */
|
| 19 |
public int strategy;
|
| 20 |
|
| 21 |
-
/** Number of threads. */
|
| 22 |
public int n_threads;
|
| 23 |
|
| 24 |
-
/** Maximum tokens to use from past text as a prompt for the decoder. */
|
| 25 |
public int n_max_text_ctx;
|
| 26 |
|
| 27 |
-
/** Start offset in milliseconds. */
|
| 28 |
public int offset_ms;
|
| 29 |
|
| 30 |
-
/** Audio duration to process in milliseconds. */
|
| 31 |
public int duration_ms;
|
| 32 |
|
| 33 |
-
/** Translate flag. */
|
| 34 |
-
public
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
/**
|
| 37 |
-
public
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
/** Flag to
|
| 40 |
-
public
|
| 41 |
|
| 42 |
-
/** Flag to
|
| 43 |
-
public boolean
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
/** Flag to
|
| 46 |
-
public
|
| 47 |
|
| 48 |
-
/** Flag to
|
| 49 |
-
public boolean
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
/** Flag to print
|
| 52 |
-
public
|
| 53 |
|
| 54 |
-
/**
|
| 55 |
-
public boolean
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
/**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
public float thold_pt;
|
| 59 |
|
| 60 |
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
| 61 |
public float thold_ptsum;
|
| 62 |
|
| 63 |
-
/** Maximum segment length in characters. */
|
| 64 |
public int max_len;
|
| 65 |
|
| 66 |
-
/** Flag to split on word rather than on token (when used with max_len). */
|
| 67 |
-
public
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
/** Maximum tokens per segment (0 = no limit)
|
| 70 |
public int max_tokens;
|
| 71 |
|
| 72 |
-
/** Flag to speed up the audio by 2x using Phase Vocoder. */
|
| 73 |
-
public
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
/** Overwrite the audio context size (0 = use default). */
|
| 76 |
public int audio_ctx;
|
|
@@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
|
|
| 79 |
* These are prepended to any existing text context from a previous call. */
|
| 80 |
public String initial_prompt;
|
| 81 |
|
| 82 |
-
/** Prompt tokens. */
|
| 83 |
public Pointer prompt_tokens;
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
/** Number of prompt tokens. */
|
| 86 |
public int prompt_n_tokens;
|
| 87 |
|
|
@@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
|
|
| 90 |
public String language;
|
| 91 |
|
| 92 |
/** Flag to indicate whether to detect language automatically. */
|
| 93 |
-
public
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
|
| 97 |
/** Flag to suppress blank tokens. */
|
| 98 |
-
public
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
/** Flag to suppress non-speech tokens. */
|
| 101 |
-
public boolean
|
|
|
|
|
|
|
| 102 |
|
| 103 |
/** Initial decoding temperature. */
|
| 104 |
public float temperature;
|
|
@@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
|
|
| 109 |
/** Length penalty. */
|
| 110 |
public float length_penalty;
|
| 111 |
|
| 112 |
-
|
| 113 |
|
| 114 |
/** Temperature increment. */
|
| 115 |
public float temperature_inc;
|
|
@@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
|
|
| 123 |
/** No speech threshold. */
|
| 124 |
public float no_speech_thold;
|
| 125 |
|
| 126 |
-
class GreedyParams extends Structure {
|
| 127 |
-
/** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
|
| 128 |
-
public int best_of;
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
/** Greedy decoding parameters. */
|
| 132 |
public GreedyParams greedy;
|
| 133 |
|
| 134 |
-
class BeamSearchParams extends Structure {
|
| 135 |
-
/** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
|
| 136 |
-
int beam_size;
|
| 137 |
-
|
| 138 |
-
/** ref: https://arxiv.org/pdf/2204.05424.pdf */
|
| 139 |
-
float patience;
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
/**
|
| 143 |
* Beam search decoding parameters.
|
| 144 |
*/
|
| 145 |
public BeamSearchParams beam_search;
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
/**
|
| 148 |
* Callback for every newly generated text segment.
|
|
|
|
| 149 |
*/
|
| 150 |
-
public
|
| 151 |
|
| 152 |
/**
|
| 153 |
* User data for the new_segment_callback.
|
|
@@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
|
|
| 156 |
|
| 157 |
/**
|
| 158 |
* Callback on each progress update.
|
|
|
|
| 159 |
*/
|
| 160 |
-
public
|
| 161 |
|
| 162 |
/**
|
| 163 |
* User data for the progress_callback.
|
|
@@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
|
|
| 166 |
|
| 167 |
/**
|
| 168 |
* Callback each time before the encoder starts.
|
|
|
|
| 169 |
*/
|
| 170 |
-
public
|
| 171 |
|
| 172 |
/**
|
| 173 |
* User data for the encoder_begin_callback.
|
|
@@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
|
|
| 176 |
|
| 177 |
/**
|
| 178 |
* Callback by each decoder to filter obtained logits.
|
|
|
|
| 179 |
*/
|
| 180 |
-
public
|
| 181 |
|
| 182 |
/**
|
| 183 |
* User data for the logits_filter_callback.
|
| 184 |
*/
|
| 185 |
public Pointer logits_filter_callback_user_data;
|
| 186 |
-
}
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
package io.github.ggerganov.whispercpp.params;
|
| 2 |
|
| 3 |
+
import com.sun.jna.*;
|
|
|
|
|
|
|
| 4 |
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
| 5 |
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
| 6 |
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
| 7 |
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
| 8 |
|
| 9 |
+
import java.util.Arrays;
|
| 10 |
+
import java.util.List;
|
| 11 |
+
|
| 12 |
/**
|
| 13 |
* Parameters for the whisper_full() function.
|
| 14 |
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
|
|
|
| 16 |
*/
|
| 17 |
public class WhisperFullParams extends Structure {
|
| 18 |
|
| 19 |
+
public WhisperFullParams(Pointer p) {
|
| 20 |
+
super(p);
|
| 21 |
+
// super(p, ALIGN_MSVC);
|
| 22 |
+
// super(p, ALIGN_GNUC);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
/** Sampling strategy for whisper_full() function. */
|
| 26 |
public int strategy;
|
| 27 |
|
| 28 |
+
/** Number of threads. (default = 4) */
|
| 29 |
public int n_threads;
|
| 30 |
|
| 31 |
+
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
|
| 32 |
public int n_max_text_ctx;
|
| 33 |
|
| 34 |
+
/** Start offset in milliseconds. (default = 0) */
|
| 35 |
public int offset_ms;
|
| 36 |
|
| 37 |
+
/** Audio duration to process in milliseconds. (default = 0) */
|
| 38 |
public int duration_ms;
|
| 39 |
|
| 40 |
+
/** Translate flag. (default = false) */
|
| 41 |
+
public CBool translate;
|
| 42 |
+
|
| 43 |
+
/** The compliment of translateMode() */
|
| 44 |
+
public void transcribeMode() {
|
| 45 |
+
translate = CBool.FALSE;
|
| 46 |
+
}
|
| 47 |
|
| 48 |
+
/** The compliment of transcribeMode() */
|
| 49 |
+
public void translateMode() {
|
| 50 |
+
translate = CBool.TRUE;
|
| 51 |
+
}
|
| 52 |
|
| 53 |
+
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
| 54 |
+
public CBool no_context;
|
| 55 |
|
| 56 |
+
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
| 57 |
+
public void enableContext(boolean enable) {
|
| 58 |
+
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
| 59 |
+
}
|
| 60 |
|
| 61 |
+
/** Flag to force single segment output (useful for streaming). (default = false) */
|
| 62 |
+
public CBool single_segment;
|
| 63 |
|
| 64 |
+
/** Flag to force single segment output (useful for streaming). (default = false) */
|
| 65 |
+
public void singleSegment(boolean single) {
|
| 66 |
+
single_segment = single ? CBool.TRUE : CBool.FALSE;
|
| 67 |
+
}
|
| 68 |
|
| 69 |
+
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
| 70 |
+
public CBool print_special;
|
| 71 |
|
| 72 |
+
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
| 73 |
+
public void printSpecial(boolean enable) {
|
| 74 |
+
print_special = enable ? CBool.TRUE : CBool.FALSE;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
/** Flag to print progress information. (default = true) */
|
| 78 |
+
public CBool print_progress;
|
| 79 |
+
|
| 80 |
+
/** Flag to print progress information. (default = true) */
|
| 81 |
+
public void printProgress(boolean enable) {
|
| 82 |
+
print_progress = enable ? CBool.TRUE : CBool.FALSE;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
| 86 |
+
public CBool print_realtime;
|
| 87 |
+
|
| 88 |
+
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
| 89 |
+
public void printRealtime(boolean enable) {
|
| 90 |
+
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
|
| 91 |
+
}
|
| 92 |
|
| 93 |
+
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
| 94 |
+
public CBool print_timestamps;
|
| 95 |
+
|
| 96 |
+
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
| 97 |
+
public void printTimestamps(boolean enable) {
|
| 98 |
+
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
| 102 |
+
public CBool token_timestamps;
|
| 103 |
+
|
| 104 |
+
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
| 105 |
+
public void tokenTimestamps(boolean enable) {
|
| 106 |
+
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
|
| 110 |
public float thold_pt;
|
| 111 |
|
| 112 |
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
| 113 |
public float thold_ptsum;
|
| 114 |
|
| 115 |
+
/** Maximum segment length in characters. (default = 0) */
|
| 116 |
public int max_len;
|
| 117 |
|
| 118 |
+
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
| 119 |
+
public CBool split_on_word;
|
| 120 |
+
|
| 121 |
+
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
| 122 |
+
public void splitOnWord(boolean enable) {
|
| 123 |
+
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
|
| 124 |
+
}
|
| 125 |
|
| 126 |
+
/** Maximum tokens per segment (0, default = no limit) */
|
| 127 |
public int max_tokens;
|
| 128 |
|
| 129 |
+
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
| 130 |
+
public CBool speed_up;
|
| 131 |
+
|
| 132 |
+
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
| 133 |
+
public void speedUp(boolean enable) {
|
| 134 |
+
speed_up = enable ? CBool.TRUE : CBool.FALSE;
|
| 135 |
+
}
|
| 136 |
|
| 137 |
/** Overwrite the audio context size (0 = use default). */
|
| 138 |
public int audio_ctx;
|
|
|
|
| 141 |
* These are prepended to any existing text context from a previous call. */
|
| 142 |
public String initial_prompt;
|
| 143 |
|
| 144 |
+
/** Prompt tokens. (int*) */
|
| 145 |
public Pointer prompt_tokens;
|
| 146 |
|
| 147 |
+
public void setPromptTokens(int[] tokens) {
|
| 148 |
+
Memory mem = new Memory(tokens.length * 4L);
|
| 149 |
+
mem.write(0, tokens, 0, tokens.length);
|
| 150 |
+
prompt_tokens = mem;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
/** Number of prompt tokens. */
|
| 154 |
public int prompt_n_tokens;
|
| 155 |
|
|
|
|
| 158 |
public String language;
|
| 159 |
|
| 160 |
/** Flag to indicate whether to detect language automatically. */
|
| 161 |
+
public CBool detect_language;
|
| 162 |
+
|
| 163 |
+
/** Flag to indicate whether to detect language automatically. */
|
| 164 |
+
public void detectLanguage(boolean enable) {
|
| 165 |
+
detect_language = enable ? CBool.TRUE : CBool.FALSE;
|
| 166 |
+
}
|
| 167 |
|
| 168 |
+
// Common decoding parameters.
|
| 169 |
|
| 170 |
/** Flag to suppress blank tokens. */
|
| 171 |
+
public CBool suppress_blank;
|
| 172 |
+
|
| 173 |
+
public void suppressBlanks(boolean enable) {
|
| 174 |
+
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
/** Flag to suppress non-speech tokens. */
|
| 178 |
+
public CBool suppress_non_speech_tokens;
|
| 179 |
|
| 180 |
/** Flag to suppress non-speech tokens. */
|
| 181 |
+
public void suppressNonSpeechTokens(boolean enable) {
|
| 182 |
+
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
|
| 183 |
+
}
|
| 184 |
|
| 185 |
/** Initial decoding temperature. */
|
| 186 |
public float temperature;
|
|
|
|
| 191 |
/** Length penalty. */
|
| 192 |
public float length_penalty;
|
| 193 |
|
| 194 |
+
// Fallback parameters.
|
| 195 |
|
| 196 |
/** Temperature increment. */
|
| 197 |
public float temperature_inc;
|
|
|
|
| 205 |
/** No speech threshold. */
|
| 206 |
public float no_speech_thold;
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
/** Greedy decoding parameters. */
|
| 209 |
public GreedyParams greedy;
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
/**
|
| 212 |
* Beam search decoding parameters.
|
| 213 |
*/
|
| 214 |
public BeamSearchParams beam_search;
|
| 215 |
|
| 216 |
+
public void setBestOf(int bestOf) {
|
| 217 |
+
if (greedy == null) {
|
| 218 |
+
greedy = new GreedyParams();
|
| 219 |
+
}
|
| 220 |
+
greedy.best_of = bestOf;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
public void setBeamSize(int beamSize) {
|
| 224 |
+
if (beam_search == null) {
|
| 225 |
+
beam_search = new BeamSearchParams();
|
| 226 |
+
}
|
| 227 |
+
beam_search.beam_size = beamSize;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
public void setBeamSizeAndPatience(int beamSize, float patience) {
|
| 231 |
+
if (beam_search == null) {
|
| 232 |
+
beam_search = new BeamSearchParams();
|
| 233 |
+
}
|
| 234 |
+
beam_search.beam_size = beamSize;
|
| 235 |
+
beam_search.patience = patience;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
/**
|
| 239 |
* Callback for every newly generated text segment.
|
| 240 |
+
* WhisperNewSegmentCallback
|
| 241 |
*/
|
| 242 |
+
public Pointer new_segment_callback;
|
| 243 |
|
| 244 |
/**
|
| 245 |
* User data for the new_segment_callback.
|
|
|
|
| 248 |
|
| 249 |
/**
|
| 250 |
* Callback on each progress update.
|
| 251 |
+
* WhisperProgressCallback
|
| 252 |
*/
|
| 253 |
+
public Pointer progress_callback;
|
| 254 |
|
| 255 |
/**
|
| 256 |
* User data for the progress_callback.
|
|
|
|
| 259 |
|
| 260 |
/**
|
| 261 |
* Callback each time before the encoder starts.
|
| 262 |
+
* WhisperEncoderBeginCallback
|
| 263 |
*/
|
| 264 |
+
public Pointer encoder_begin_callback;
|
| 265 |
|
| 266 |
/**
|
| 267 |
* User data for the encoder_begin_callback.
|
|
|
|
| 270 |
|
| 271 |
/**
|
| 272 |
* Callback by each decoder to filter obtained logits.
|
| 273 |
+
* WhisperLogitsFilterCallback
|
| 274 |
*/
|
| 275 |
+
public Pointer logits_filter_callback;
|
| 276 |
|
| 277 |
/**
|
| 278 |
* User data for the logits_filter_callback.
|
| 279 |
*/
|
| 280 |
public Pointer logits_filter_callback_user_data;
|
|
|
|
| 281 |
|
| 282 |
+
|
| 283 |
+
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
|
| 284 |
+
new_segment_callback = CallbackReference.getFunctionPointer(callback);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
public void setProgressCallback(WhisperProgressCallback callback) {
|
| 288 |
+
progress_callback = CallbackReference.getFunctionPointer(callback);
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
|
| 292 |
+
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
|
| 296 |
+
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
@Override
|
| 300 |
+
protected List<String> getFieldOrder() {
|
| 301 |
+
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
| 302 |
+
"no_context", "single_segment",
|
| 303 |
+
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
| 304 |
+
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
| 305 |
+
"initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
| 306 |
+
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
| 307 |
+
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
| 308 |
+
"new_segment_callback", "new_segment_callback_user_data",
|
| 309 |
+
"progress_callback", "progress_callback_user_data",
|
| 310 |
+
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
| 311 |
+
"logits_filter_callback", "logits_filter_callback_user_data");
|
| 312 |
+
}
|
| 313 |
+
}
|
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
package io.github.ggerganov.whispercpp.params;
|
| 2 |
-
|
| 3 |
-
import com.sun.jna.Structure;
|
| 4 |
-
|
| 5 |
-
public class WhisperJavaParams extends Structure {
|
| 6 |
-
|
| 7 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java
CHANGED
|
@@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
|
|
| 2 |
|
| 3 |
import static org.junit.jupiter.api.Assertions.*;
|
| 4 |
|
| 5 |
-
import io.github.ggerganov.whispercpp.params.
|
|
|
|
| 6 |
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
| 7 |
import org.junit.jupiter.api.BeforeAll;
|
| 8 |
import org.junit.jupiter.api.Test;
|
|
@@ -19,11 +20,11 @@ class WhisperCppTest {
|
|
| 19 |
static void init() throws FileNotFoundException {
|
| 20 |
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
| 21 |
// or you can provide the absolute path to the model file.
|
| 22 |
-
String modelName = "
|
| 23 |
try {
|
| 24 |
whisper.initContext(modelName);
|
| 25 |
-
whisper.
|
| 26 |
-
// whisper.
|
| 27 |
modelInitialised = true;
|
| 28 |
} catch (FileNotFoundException ex) {
|
| 29 |
System.out.println("Model " + modelName + " not found");
|
|
@@ -31,11 +32,30 @@ class WhisperCppTest {
|
|
| 31 |
}
|
| 32 |
|
| 33 |
@Test
|
| 34 |
-
void
|
| 35 |
// When
|
| 36 |
-
whisper.
|
| 37 |
|
| 38 |
-
// Then
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
@Test
|
|
@@ -52,6 +72,13 @@ class WhisperCppTest {
|
|
| 52 |
byte[] b = new byte[audioInputStream.available()];
|
| 53 |
float[] floats = new float[b.length / 2];
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
try {
|
| 56 |
audioInputStream.read(b);
|
| 57 |
|
|
@@ -61,13 +88,13 @@ class WhisperCppTest {
|
|
| 61 |
}
|
| 62 |
|
| 63 |
// When
|
| 64 |
-
String result = whisper.fullTranscribe(
|
| 65 |
|
| 66 |
// Then
|
| 67 |
-
System.
|
| 68 |
-
assertEquals("And so my fellow Americans
|
| 69 |
"ask what you can do for your country.",
|
| 70 |
-
result);
|
| 71 |
} finally {
|
| 72 |
audioInputStream.close();
|
| 73 |
}
|
|
|
|
| 2 |
|
| 3 |
import static org.junit.jupiter.api.Assertions.*;
|
| 4 |
|
| 5 |
+
import io.github.ggerganov.whispercpp.params.CBool;
|
| 6 |
+
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
| 7 |
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
| 8 |
import org.junit.jupiter.api.BeforeAll;
|
| 9 |
import org.junit.jupiter.api.Test;
|
|
|
|
| 20 |
static void init() throws FileNotFoundException {
|
| 21 |
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
| 22 |
// or you can provide the absolute path to the model file.
|
| 23 |
+
String modelName = "../../models/ggml-tiny.en.bin";
|
| 24 |
try {
|
| 25 |
whisper.initContext(modelName);
|
| 26 |
+
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
| 27 |
+
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
| 28 |
modelInitialised = true;
|
| 29 |
} catch (FileNotFoundException ex) {
|
| 30 |
System.out.println("Model " + modelName + " not found");
|
|
|
|
| 32 |
}
|
| 33 |
|
| 34 |
@Test
|
| 35 |
+
void testGetDefaultFullParams_BeamSearch() {
|
| 36 |
// When
|
| 37 |
+
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
| 38 |
|
| 39 |
+
// Then
|
| 40 |
+
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
|
| 41 |
+
assertNotEquals(0, params.n_threads);
|
| 42 |
+
assertEquals(16384, params.n_max_text_ctx);
|
| 43 |
+
assertFalse(params.translate);
|
| 44 |
+
assertEquals(0.01f, params.thold_pt);
|
| 45 |
+
assertEquals(2, params.beam_search.beam_size);
|
| 46 |
+
assertEquals(-1.0f, params.beam_search.patience);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
@Test
|
| 50 |
+
void testGetDefaultFullParams_Greedy() {
|
| 51 |
+
// When
|
| 52 |
+
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
| 53 |
+
|
| 54 |
+
// Then
|
| 55 |
+
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
| 56 |
+
assertNotEquals(0, params.n_threads);
|
| 57 |
+
assertEquals(16384, params.n_max_text_ctx);
|
| 58 |
+
assertEquals(2, params.greedy.best_of);
|
| 59 |
}
|
| 60 |
|
| 61 |
@Test
|
|
|
|
| 72 |
byte[] b = new byte[audioInputStream.available()];
|
| 73 |
float[] floats = new float[b.length / 2];
|
| 74 |
|
| 75 |
+
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
| 76 |
+
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
| 77 |
+
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
| 78 |
+
params.print_progress = CBool.FALSE;
|
| 79 |
+
// params.initial_prompt = "and so my fellow Americans um, like";
|
| 80 |
+
|
| 81 |
+
|
| 82 |
try {
|
| 83 |
audioInputStream.read(b);
|
| 84 |
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
// When
|
| 91 |
+
String result = whisper.fullTranscribe(params, floats);
|
| 92 |
|
| 93 |
// Then
|
| 94 |
+
System.err.println(result);
|
| 95 |
+
assertEquals("And so my fellow Americans ask not what your country can do for you " +
|
| 96 |
"ask what you can do for your country.",
|
| 97 |
+
result.replace(",", ""));
|
| 98 |
} finally {
|
| 99 |
audioInputStream.close();
|
| 100 |
}
|
whisper.cpp
CHANGED
|
@@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
|
|
| 2852 |
}
|
| 2853 |
}
|
| 2854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2855 |
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 2856 |
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
| 2857 |
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
|
@@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
|
|
| 3285 |
|
| 3286 |
////////////////////////////////////////////////////////////////////////////
|
| 3287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3288 |
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
| 3289 |
struct whisper_full_params result = {
|
| 3290 |
/*.strategy =*/ strategy,
|
|
|
|
| 2852 |
}
|
| 2853 |
}
|
| 2854 |
|
| 2855 |
+
void whisper_free_params(struct whisper_full_params * params) {
|
| 2856 |
+
if (params) {
|
| 2857 |
+
delete params;
|
| 2858 |
+
}
|
| 2859 |
+
}
|
| 2860 |
+
|
| 2861 |
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 2862 |
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
| 2863 |
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
|
|
|
| 3291 |
|
| 3292 |
////////////////////////////////////////////////////////////////////////////
|
| 3293 |
|
| 3294 |
+
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
|
| 3295 |
+
struct whisper_full_params params = whisper_full_default_params(strategy);
|
| 3296 |
+
|
| 3297 |
+
struct whisper_full_params* result = new whisper_full_params();
|
| 3298 |
+
*result = params;
|
| 3299 |
+
return result;
|
| 3300 |
+
}
|
| 3301 |
+
|
| 3302 |
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
| 3303 |
struct whisper_full_params result = {
|
| 3304 |
/*.strategy =*/ strategy,
|
whisper.h
CHANGED
|
@@ -113,6 +113,7 @@ extern "C" {
|
|
| 113 |
// Frees all allocated memory
|
| 114 |
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
| 115 |
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
|
|
|
| 116 |
|
| 117 |
// Convert RAW PCM audio to log mel spectrogram.
|
| 118 |
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
|
@@ -409,6 +410,8 @@ extern "C" {
|
|
| 409 |
void * logits_filter_callback_user_data;
|
| 410 |
};
|
| 411 |
|
|
|
|
|
|
|
| 412 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
| 413 |
|
| 414 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
|
|
|
| 113 |
// Frees all allocated memory
|
| 114 |
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
| 115 |
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
| 116 |
+
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
|
| 117 |
|
| 118 |
// Convert RAW PCM audio to log mel spectrogram.
|
| 119 |
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
|
|
|
| 410 |
void * logits_filter_callback_user_data;
|
| 411 |
};
|
| 412 |
|
| 413 |
+
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
| 414 |
+
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
|
| 415 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
| 416 |
|
| 417 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|