Skip to content

Commit 2cdfc4e

Browse files
authored
whisper : add support for large v3 (ggml-org#1444)
* whisper : add support for large v3
* bench : fix build + fix go bindings
* bench : fix n_mels
* models : update readme
1 parent 9731110 commit 2cdfc4e

20 files changed

+70
-38
lines changed

Makefile

+2-1
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,10 @@ samples:
417417
.PHONY: medium.en
418418
.PHONY: medium
419419
.PHONY: large-v1
420+
.PHONY: large-v2
420421
.PHONY: large
421422

422-
tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
423+
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large: main
423424
bash ./models/download-ggml-model.sh $@
424425
@echo ""
425426
@echo "==============================================="

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ make small
234234
make medium.en
235235
make medium
236236
make large-v1
237+
make large-v2
237238
make large
238239
```
239240

@@ -245,7 +246,7 @@ make large
245246
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
246247
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
247248
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
248-
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
249+
| large | 2.9 GB | ~3.3 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
249250

250251
## Quantization
251252

bindings/go/examples/go-model-download/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ const (
2424

2525
var (
2626
// The models which will be downloaded, if no model is specified as an argument
27-
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
27+
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large"}
2828
)
2929

3030
var (

bindings/go/whisper.go

-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ const (
8383
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
8484
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
8585
NumFFT = C.WHISPER_N_FFT
86-
NumMEL = C.WHISPER_N_MEL
8786
HopLength = C.WHISPER_HOP_LENGTH
8887
ChunkSize = C.WHISPER_CHUNK_SIZE
8988
)

examples/bench.wasm/emscripten.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ void bench_main(size_t index) {
2323

2424
fprintf(stderr, "%s: running benchmark with %d threads - please wait...\n", __func__, n_threads);
2525

26-
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
26+
const int n_mels = whisper_model_n_mels(ctx);
27+
28+
if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
2729
fprintf(stderr, "error: failed to set mel: %d\n", ret);
2830
return;
2931
}

examples/bench/bench.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ int whisper_bench_full(const whisper_params & params) {
7373
return 2;
7474
}
7575

76-
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
76+
const int n_mels = whisper_model_n_mels(ctx);
77+
78+
if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
7779
fprintf(stderr, "error: failed to set mel: %d\n", ret);
7880
return 3;
7981
}

examples/livestream.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ if [ -n "$3" ]; then
4848
fi
4949

5050
# Whisper models
51-
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
51+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
5252

5353
# list available models
5454
function list_models {

examples/twitch.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ help()
2121
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
2222
echo "options:"
2323
echo "-s Step in seconds (default is $step)."
24-
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large' (default is '$model')."
24+
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large' (default is '$model')."
2525
echo "-t Number of threads to use."
2626
echo "-h Print this help page."
2727
echo

extra/convert-all.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
3+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
44

55
for model in "${models[@]}"; do
66
python3 models/convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/

models/README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ https://huggingface.co/ggerganov/whisper.cpp/tree/main
5050
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
5151
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
5252
| large-v1 | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
53-
| large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
53+
| large-v2 | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
54+
| large | 2.9 GB | ~4.7 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
5455

5556
## Model files for testing purposes
5657

models/convert-h5-to-coreml.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
7878
# Ported from models/convert-whisper-to-coreml.py
7979
if __name__ == "__main__":
8080
parser = argparse.ArgumentParser()
81-
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
81+
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
8282
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
8383
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
8484
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
8585
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
8686
args = parser.parse_args()
8787

88-
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
88+
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
8989
raise ValueError("Invalid model name")
9090

9191
pt_target_path = f"models/hf-{args.model_name}.pt"

models/convert-pt-to-ggml.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def bytes_to_unicode():
228228
# for backwards compatibility, also check for older hf_transformers format tokenizer files
229229
# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
230230
# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
231-
multilingual = hparams["n_vocab"] == 51865
231+
multilingual = hparams["n_vocab"] >= 51865
232232
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
233233
tokenizer_type = "tiktoken"
234234
if not tokenizer.is_file():

models/convert-whisper-to-coreml.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
194194
x = x.permute(0,2,3,1).squeeze(0)
195195

196196
# ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks
197-
if self.token_embedding.weight.shape[0] == 51865:
197+
if self.token_embedding.weight.shape[0] >= 51865:
198198
# split in 11 chunks - 4715 each
199199
splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//11, dim=0)
200200
logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)
@@ -296,13 +296,13 @@ def convert_decoder(hparams, model, quantize=False):
296296

297297
if __name__ == "__main__":
298298
parser = argparse.ArgumentParser()
299-
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
299+
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
300300
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
301301
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
302302
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
303303
args = parser.parse_args()
304304

305-
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
305+
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
306306
raise ValueError("Invalid model name")
307307

308308
whisper = load_model(args.model).cpu()

models/convert-whisper-to-openvino.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def convert_encoder(hparams, encoder, mname):
3838

3939
if __name__ == "__main__":
4040
parser = argparse.ArgumentParser()
41-
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
41+
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
4242
args = parser.parse_args()
4343

44-
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
44+
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
4545
raise ValueError("Invalid model name")
4646

4747
whisper = load_model(args.model).cpu()

models/download-coreml-model.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function get_script_path() {
1919
models_path="$(get_script_path)"
2020

2121
# Whisper models
22-
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
22+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
2323

2424
# list available models
2525
function list_models {

models/download-ggml-model.cmd

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ popd
88
set argc=0
99
for %%x in (%*) do set /A argc+=1
1010

11-
set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large
11+
set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large
1212

1313
if %argc% neq 1 (
1414
echo.
@@ -57,8 +57,8 @@ goto :eof
5757
:list_models
5858
echo.
5959
echo Available models:
60-
(for %%a in (%models%) do (
61-
echo %%a
60+
(for %%a in (%models%) do (
61+
echo %%a
6262
))
6363
echo.
6464
exit /b

models/download-ggml-model.sh

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ models=(
4141
"medium-q5_0"
4242
"medium.en-q5_0"
4343
"large-v1"
44+
"large-v2"
4445
"large"
4546
"large-q5_0"
4647
)

tests/run-tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
cd `dirname $0`
2020

2121
# Whisper models
22-
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
22+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
2323

2424
# list available models
2525
function list_models {

whisper.cpp

+40-14
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ enum e_model {
193193
MODEL_LARGE,
194194
};
195195

196+
static const std::map<e_model, std::string> g_model_name = {
197+
{ MODEL_UNKNOWN, "unknown" },
198+
{ MODEL_TINY, "tiny" },
199+
{ MODEL_BASE, "base" },
200+
{ MODEL_SMALL, "small" },
201+
{ MODEL_MEDIUM, "medium" },
202+
{ MODEL_LARGE, "large" },
203+
};
204+
196205
static const std::map<std::string, std::pair<int, std::string>> g_lang = {
197206
{ "en", { 0, "english", } },
198207
{ "zh", { 1, "chinese", } },
@@ -293,6 +302,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
293302
{ "ba", { 96, "bashkir", } },
294303
{ "jw", { 97, "javanese", } },
295304
{ "su", { 98, "sundanese", } },
305+
{ "yue", { 99, "cantonese", } },
296306
};
297307

298308
static const size_t MB = 1ull*1024*1024;
@@ -402,7 +412,11 @@ struct whisper_vocab {
402412
id token_beg = 50363; // begin timestamps
403413

404414
bool is_multilingual() const {
405-
return n_vocab == 51865;
415+
return n_vocab >= 51865;
416+
}
417+
418+
int num_languages() const {
419+
return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
406420
}
407421
};
408422

@@ -922,6 +936,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
922936

923937
assert(hparams.n_text_state == hparams.n_audio_state);
924938

939+
std::string mver = "";
940+
925941
if (hparams.n_audio_layer == 4) {
926942
model.type = e_model::MODEL_TINY;
927943
}
@@ -940,6 +956,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
940956

941957
if (hparams.n_audio_layer == 32) {
942958
model.type = e_model::MODEL_LARGE;
959+
960+
if (hparams.n_vocab == 51866) {
961+
mver = " v3";
962+
}
943963
}
944964

945965
const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -968,7 +988,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
968988
log("%s: n_mels = %d\n", __func__, hparams.n_mels);
969989
log("%s: ftype = %d\n", __func__, model.hparams.ftype);
970990
log("%s: qntvr = %d\n", __func__, qntvr);
971-
log("%s: type = %d\n", __func__, model.type);
991+
log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
972992

973993
// print memory requirements
974994
{
@@ -1039,13 +1059,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
10391059
if (vocab.is_multilingual()) {
10401060
vocab.token_eot++;
10411061
vocab.token_sot++;
1042-
vocab.token_translate++;
1043-
vocab.token_transcribe++;
1044-
vocab.token_solm++;
1045-
vocab.token_prev++;
1046-
vocab.token_nosp++;
1047-
vocab.token_not++;
1048-
vocab.token_beg++;
1062+
1063+
// account for variable number of language tokens
1064+
const int dt = vocab.num_languages() - 98;
1065+
1066+
vocab.token_translate += dt;
1067+
vocab.token_transcribe += dt;
1068+
vocab.token_solm += dt;
1069+
vocab.token_prev += dt;
1070+
vocab.token_nosp += dt;
1071+
vocab.token_not += dt;
1072+
vocab.token_beg += dt;
10491073
}
10501074

10511075
if (n_vocab < model.hparams.n_vocab) {
@@ -1074,6 +1098,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
10741098
vocab.id_to_token[i] = word;
10751099
}
10761100
}
1101+
1102+
log("%s: n_langs = %d\n", __func__, vocab.num_languages());
10771103
}
10781104

10791105
size_t ctx_size = 0;
@@ -3281,7 +3307,7 @@ void whisper_free_params(struct whisper_full_params * params) {
32813307
}
32823308

32833309
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3284-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
3310+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
32853311
log("%s: failed to compute mel spectrogram\n", __func__);
32863312
return -1;
32873313
}
@@ -3295,7 +3321,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
32953321

32963322
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
32973323
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3298-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
3324+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
32993325
log("%s: failed to compute mel spectrogram\n", __func__);
33003326
return -1;
33013327
}
@@ -3318,13 +3344,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float *
33183344
// TODO
33193345

33203346
int whisper_set_mel_with_state(
3321-
struct whisper_context * /*ctx*/,
3347+
struct whisper_context * ctx,
33223348
struct whisper_state * state,
33233349
const float * data,
33243350
int n_len,
33253351
int n_mel) {
3326-
if (n_mel != WHISPER_N_MEL) {
3327-
log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
3352+
if (n_mel != ctx->model.filters.n_mel) {
3353+
log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
33283354
return -1;
33293355
}
33303356

whisper.h

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
#define WHISPER_SAMPLE_RATE 16000
3131
#define WHISPER_N_FFT 400
32-
#define WHISPER_N_MEL 80
3332
#define WHISPER_HOP_LENGTH 160
3433
#define WHISPER_CHUNK_SIZE 30
3534

0 commit comments

Comments
 (0)