
Commit e070915

jhen0409 authored and ggerganov committed
whisper : add context param to disable gpu (ggml-org#1293)
* whisper : check state->ctx_metal not null
* whisper : add whisper_context_params { use_gpu }
* whisper : new API with params & deprecate old API
* examples : use no-gpu param && whisper_init_from_file_with_params
* whisper.objc : enable metal & disable on simulator
* whisper.swiftui, metal : enable metal & support load default.metallib
* whisper.android : use new API
* bindings : use new API
* addon.node : fix build & test
* bindings : updata java binding
* bindings : add missing whisper_context_default_params_by_ref WHISPER_API for java
* metal : use SWIFTPM_MODULE_BUNDLE for GGML_SWIFT and reuse library load
* metal : move bundle var into block
* metal : use SWIFT_PACKAGE instead of GGML_SWIFT
* style : minor updates

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
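In practice, the new context-params API lets a caller opt out of GPU (Metal) inference per context. Below is a minimal C++ sketch of the intended call sequence, assuming the whisper.h declarations this commit introduces (whisper_context_params, whisper_context_default_params(), whisper_init_from_file_with_params()) and an example model path; it is an illustration, not code from the commit.

#include "whisper.h"
#include <cstdio>

int main() {
    // Start from the library defaults, then opt out of GPU inference.
    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu = false;

    // The *_with_params initializer replaces the now-deprecated whisper_init_from_file().
    struct whisper_context * ctx =
        whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to initialize whisper context\n");
        return 1;
    }

    // ... run whisper_full() as before ...

    whisper_free(ctx);
    return 0;
}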
1 parent ba08fd6 commit e070915

File tree: 29 files changed, +422 / -171 lines


bindings/go/whisper.go

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ var (
 func Whisper_init(path string) *Context {
     cPath := C.CString(path)
     defer C.free(unsafe.Pointer(cPath))
-    if ctx := C.whisper_init_from_file(cPath); ctx != nil {
+    if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
         return (*Context)(ctx)
     } else {
         return nil

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java

Lines changed: 3 additions & 1 deletion
@@ -4,6 +4,7 @@
 import com.sun.jna.ptr.PointerByReference;
 import io.github.ggerganov.whispercpp.ggml.GgmlType;
 import io.github.ggerganov.whispercpp.WhisperModel;
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;

 import java.util.List;

@@ -23,8 +24,9 @@ public class WhisperContext extends Structure {
     public PointerByReference vocab;
     public PointerByReference state;

-    /** populated by whisper_init_from_file() */
+    /** populated by whisper_init_from_file_with_params() */
     String path_model;
+    WhisperContextParams params;

 //    public static class ByReference extends WhisperContext implements Structure.ByReference {
 //    }

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java

Lines changed: 46 additions & 15 deletions
@@ -2,6 +2,7 @@

 import com.sun.jna.Native;
 import com.sun.jna.Pointer;
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;
 import io.github.ggerganov.whispercpp.params.WhisperFullParams;
 import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;

@@ -15,8 +16,9 @@
 public class WhisperCpp implements AutoCloseable {
     private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
     private Pointer ctx = null;
-    private Pointer greedyPointer = null;
-    private Pointer beamPointer = null;
+    private Pointer paramsPointer = null;
+    private Pointer greedyParamsPointer = null;
+    private Pointer beamParamsPointer = null;

     public File modelDir() {
         String modelDirPath = System.getenv("XDG_CACHE_HOME");
@@ -31,6 +33,18 @@ public File modelDir() {
      * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
      */
     public void initContext(String modelPath) throws FileNotFoundException {
+        initContextImpl(modelPath, getContextDefaultParams());
+    }
+
+    /**
+     * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
+     * @param params - params to use when initialising the context
+     */
+    public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
+        initContextImpl(modelPath, params);
+    }
+
+    private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
         if (ctx != null) {
             lib.whisper_free(ctx);
         }
@@ -43,13 +57,26 @@ public void initContext(String modelPath) throws FileNotFoundException {
             modelPath = new File(modelDir(), modelPath).getAbsolutePath();
         }

-        ctx = lib.whisper_init_from_file(modelPath);
+        ctx = lib.whisper_init_from_file_with_params(modelPath, params);

         if (ctx == null) {
             throw new FileNotFoundException(modelPath);
         }
     }

+    /**
+     * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
+     * Because this function allocates memory for the params, the caller must call either:
+     * - call `whisper_free_context_params()`
+     * - `Native.free(Pointer.nativeValue(pointer));`
+     */
+    public WhisperContextParams getContextDefaultParams() {
+        paramsPointer = lib.whisper_context_default_params_by_ref();
+        WhisperContextParams params = new WhisperContextParams(paramsPointer);
+        params.read();
+        return params;
+    }
+
     /**
      * Provides default params which can be used with `whisper_full()` etc.
      * Because this function allocates memory for the params, the caller must call either:
@@ -63,15 +90,15 @@ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy)

         // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
         if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
-            if (greedyPointer == null) {
-                greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
+            if (greedyParamsPointer == null) {
+                greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
             }
-            pointer = greedyPointer;
+            pointer = greedyParamsPointer;
         } else {
-            if (beamPointer == null) {
-                beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
+            if (beamParamsPointer == null) {
+                beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
             }
-            pointer = beamPointer;
+            pointer = beamParamsPointer;
         }

         WhisperFullParams params = new WhisperFullParams(pointer);
@@ -93,13 +120,17 @@ private void freeContext() {
     }

     private void freeParams() {
-        if (greedyPointer != null) {
-            Native.free(Pointer.nativeValue(greedyPointer));
-            greedyPointer = null;
+        if (paramsPointer != null) {
+            Native.free(Pointer.nativeValue(paramsPointer));
+            paramsPointer = null;
+        }
+        if (greedyParamsPointer != null) {
+            Native.free(Pointer.nativeValue(greedyParamsPointer));
+            greedyParamsPointer = null;
         }
-        if (beamPointer != null) {
-            Native.free(Pointer.nativeValue(beamPointer));
-            beamPointer = null;
+        if (beamParamsPointer != null) {
+            Native.free(Pointer.nativeValue(beamParamsPointer));
+            beamParamsPointer = null;
         }
     }

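The freeParams() change above exists because the by-ref params allocation is owned by the caller, as the new Javadoc notes. A rough C++ sketch of the native sequence the Java wrapper goes through, assuming the by-ref helpers declared for the binding (whisper_context_default_params_by_ref(), whisper_free_context_params()); this is an illustration of the flow, not code from the commit.

#include "whisper.h"

// Sketch only: mirrors initContext() + freeParams() from the Java wrapper above.
static struct whisper_context * init_no_gpu(const char * model_path) {
    struct whisper_context_params * cparams = whisper_context_default_params_by_ref(); // heap-allocated defaults
    cparams->use_gpu = false;

    struct whisper_context * ctx = whisper_init_from_file_with_params(model_path, *cparams);

    // The by-ref params must be released by the caller, which is what the new
    // paramsPointer branch in freeParams() does via Native.free().
    whisper_free_context_params(cparams);
    return ctx;
}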

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java

Lines changed: 21 additions & 1 deletion
@@ -5,6 +5,7 @@
 import com.sun.jna.Pointer;
 import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
 import io.github.ggerganov.whispercpp.model.WhisperTokenData;
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;
 import io.github.ggerganov.whispercpp.params.WhisperFullParams;

 public interface WhisperCppJnaLibrary extends Library {
@@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library {
     String whisper_print_system_info();

     /**
-     * Allocate (almost) all memory needed for the model by loading from a file.
+     * DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
      *
      * @param path_model Path to the model file
      * @return Whisper context on success, null on failure
      */
     Pointer whisper_init_from_file(String path_model);
+
+    /**
+     * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
+     * Because this function allocates memory for the params, the caller must call either:
+     * - call `whisper_free_context_params()`
+     * - `Native.free(Pointer.nativeValue(pointer));`
+     */
+    Pointer whisper_context_default_params_by_ref();
+
+    void whisper_free_context_params(Pointer params);
+
+    /**
+     * Allocate (almost) all memory needed for the model by loading from a file.
+     *
+     * @param path_model Path to the model file
+     * @param params Pointer to whisper_context_params
+     * @return Whisper context on success, null on failure
+     */
+    Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);

     /**
      * Allocate (almost) all memory needed for the model by loading from a buffer.
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+package io.github.ggerganov.whispercpp.params;
+
+import com.sun.jna.*;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Parameters for the whisper_init_from_file_with_params() function.
+ * If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
+ * whisper_context_default_params()
+ */
+public class WhisperContextParams extends Structure {
+
+    public WhisperContextParams(Pointer p) {
+        super(p);
+    }
+
+    /** Use GPU for inference Number (default = true) */
+    public CBool use_gpu;
+
+    /** Use GPU for inference Number (default = true) */
+    public void useGpu(boolean enable) {
+        use_gpu = enable ? CBool.TRUE : CBool.FALSE;
+    }
+
+    @Override
+    protected List<String> getFieldOrder() {
+        return Arrays.asList("use_gpu");
+    }
+}
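Layout note: a JNA Structure has to mirror the native struct field for field, which is why getFieldOrder() lists exactly one entry. Per the commit message ("add whisper_context_params { use_gpu }"), the C-side struct this class maps to looks roughly like the following sketch; the comment is mine, not from the commit.

// whisper.h (sketch, based on the commit message):
struct whisper_context_params {
    bool use_gpu;   // must stay in sync with getFieldOrder() above
};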

bindings/javascript/emscripten.cpp

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ struct whisper_context * g_context;
 EMSCRIPTEN_BINDINGS(whisper) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         if (g_context == nullptr) {
-            g_context = whisper_init_from_file(path_model.c_str());
+            g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
             if (g_context != nullptr) {
                 return true;
             } else {

bindings/ruby/ext/ruby_whisper.cpp

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
     rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
   }
-  rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
+  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
   if (rw->context == nullptr) {
     rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
   }

examples/addon.node/__test__/whisper.spec.js

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ const whisperParamsMock = {
   language: "en",
   model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
   fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
+  use_gpu: true,
 };

 describe("Run whisper.node", () => {

examples/addon.node/addon.cpp

Lines changed: 6 additions & 1 deletion
@@ -51,6 +51,7 @@ struct whisper_params {
     bool print_colors = false;
     bool print_progress = false;
     bool no_timestamps = false;
+    bool use_gpu = true;

     std::string language = "en";
     std::string prompt;
@@ -191,7 +192,9 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result, N

     // whisper init

-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
@@ -365,10 +368,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
     std::string language = whisper_params.Get("language").As<Napi::String>();
     std::string model = whisper_params.Get("model").As<Napi::String>();
     std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
+    bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();

     params.language = language;
     params.model = model;
     params.fname_inp.emplace_back(input);
+    params.use_gpu = use_gpu;

     Napi::Function callback = info[1].As<Napi::Function>();
     Napi::ThreadSafeFunction segment_callback = Napi::ThreadSafeFunction::New(
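One note on the init hunk above: the struct currently carries only use_gpu, so declaring cparams without initialization and setting that one field works here, but seeding from the library defaults is the more defensive pattern if whisper_context_params grows more fields. A hedged alternative sketch (not what the commit does):

// Alternative to the hunk above: seed from defaults before overriding.
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);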

examples/addon.node/index.js

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ const { promisify } = require("util");
 const whisperParams = {
   language: "en",
   model: path.join(__dirname, "../../models/ggml-base.en.bin"),
-  fname_inp: "../../samples/jfk.wav"
+  fname_inp: "../../samples/jfk.wav",
+  use_gpu: true,
 };

 const arguments = process.argv.slice(2);
