Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

whisper : add context param for disable gpu #1293

Merged
merged 17 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
bindings : use new API
  • Loading branch information
jhen0409 committed Oct 6, 2023
commit e0ebea2dfa4f263c8af5b366d9cdd4f714423fcd
2 changes: 1 addition & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
return (*Context)(ctx)
} else {
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
private Pointer paramPointer = null;
private Pointer greedyParamsPointer = null;
private Pointer beamParamsPointer = null;

public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
Expand All @@ -43,7 +44,8 @@ public void initContext(String modelPath) throws FileNotFoundException {
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}

ctx = lib.whisper_init_from_file(modelPath);
paramPointer = lib.whisper_context_default_params_by_ref();
ctx = lib.whisper_init_from_file_with_params(modelPath, paramPointer);

if (ctx == null) {
throw new FileNotFoundException(modelPath);
Expand All @@ -63,15 +65,15 @@ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy)

// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (greedyParamsPointer == null) {
greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
pointer = greedyParamsPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (beamParamsPointer == null) {
beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
pointer = beamParamsPointer;
}

WhisperFullParams params = new WhisperFullParams(pointer);
Expand All @@ -93,13 +95,17 @@ private void freeContext() {
}

private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
if (paramPointer != null) {
Native.free(Pointer.nativeValue(paramPointer));
paramPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
if (greedyParamsPointer != null) {
Native.free(Pointer.nativeValue(greedyParamsPointer));
greedyParamsPointer = null;
}
if (beamParamsPointer != null) {
Native.free(Pointer.nativeValue(beamParamsPointer));
beamParamsPointer = null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,31 @@ public interface WhisperCppJnaLibrary extends Library {
String whisper_print_system_info();

/**
* Allocate (almost) all memory needed for the model by loading from a file.
* DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);

/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
Pointer whisper_context_default_params_by_ref();

void whisper_free_context_params(Pointer params);

/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_with_params(String path_model, Pointer params);

/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
Expand Down
2 changes: 1 addition & 1 deletion bindings/javascript/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str());
g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_context != nullptr) {
return true;
} else {
Expand Down
2 changes: 1 addition & 1 deletion bindings/ruby/ext/ruby_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
Expand Down
6 changes: 6 additions & 0 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3267,6 +3267,12 @@ void whisper_free(struct whisper_context * ctx) {
}
}

void whisper_free_context_params(struct whisper_context_params * params) {
if (params) {
delete params;
}
}

void whisper_free_params(struct whisper_full_params * params) {
if (params) {
delete params;
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ extern "C" {
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);

// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
Expand Down