Skip to content

Commit d837526

Browse files
committed
Adding fp16 mode whisper
1 parent b5c6c73 commit d837526

File tree

1 file changed

+13
-3
lines changed
  • speech_recognition/whisper

1 file changed

+13
-3
lines changed

speech_recognition/whisper/run.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
f"\033[0m")
1818
sys.exit(1)
1919

20-
21-
def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
20+
def run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False):
2221
import os
2322
import sys
2423
import torch
@@ -32,6 +31,10 @@ def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
3231
from speech_recognition.whisper.whisper.whisper.transcribe import transcribe
3332
model = load_model(model_name)
3433
model.eval()
34+
if use_torch_fp16:
35+
model = model.half()
36+
model._encoder.half()
37+
model._decoder.half()
3538

3639
def single_pass_pytorch(_runner, _librispeech):
3740
array = _librispeech.get_input_array()
@@ -40,15 +43,22 @@ def single_pass_pytorch(_runner, _librispeech):
4043
_runner.run(batch_size * array.shape[0], audio)["text"].lstrip().replace(".", "").upper()
4144
)
4245

46+
decode_options = {"fp16": use_torch_fp16}
47+
4348
def transcribe_wrapper(audio):
44-
return transcribe(model, audio, no_speech_threshold=1.0, verbose=None)
49+
return transcribe(model, audio, no_speech_threshold=1.0, verbose=None, **decode_options)
4550

4651
runner = PyTorchRunnerV2(transcribe_wrapper, throughput_only=True)
4752
librispeech = LibriSpeech()
4853
print_warning_message("Sampling rate Whisper operates at is 16,000 Hz, therefore throughput values below can be "
4954
"divided by 16,000 to derive 'seconds of processed audio per second'")
5055
return run_model(single_pass_pytorch, runner, librispeech, batch_size, num_runs, timeout)
5156

57+
def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
    """Run the Whisper PyTorch benchmark in full precision (FP32).

    Thin wrapper that delegates to ``run_pytorch`` with
    ``use_torch_fp16=False``.

    Args:
        model_name: Forwarded unchanged to ``run_pytorch``.
        num_runs: Forwarded unchanged to ``run_pytorch``.
        timeout: Forwarded unchanged to ``run_pytorch``.
        **kwargs: Accepted and ignored. The earlier signature of this
            function took ``**kwargs``, so callers that pass extra keyword
            arguments keep working instead of raising ``TypeError``.

    Returns:
        Whatever ``run_pytorch`` returns.
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
59+
60+
def run_pytorch_fp16(model_name, num_runs, timeout, **kwargs):
    """Run the Whisper PyTorch benchmark in half precision (FP16).

    Thin wrapper that delegates to ``run_pytorch`` with
    ``use_torch_fp16=True``.

    Args:
        model_name: Forwarded unchanged to ``run_pytorch``.
        num_runs: Forwarded unchanged to ``run_pytorch``.
        timeout: Forwarded unchanged to ``run_pytorch``.
        **kwargs: Accepted and ignored, matching the other ``run_pytorch_*``
            entry points in this file that take ``**kwargs``, so extra
            keyword arguments do not raise ``TypeError``.

    Returns:
        Whatever ``run_pytorch`` returns.
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
5262

5363
def run_pytorch_cuda(model_name, num_runs, timeout, **kwargs):
5464
import os

0 commit comments

Comments
 (0)