17
17
f"\033 [0m" )
18
18
sys .exit (1 )
19
19
20
-
21
- def run_pytorch_fp32 (model_name , num_runs , timeout , ** kwargs ):
20
+ def run_pytorch (model_name , num_runs , timeout , use_torch_fp16 = False ):
22
21
import os
23
22
import sys
24
23
import torch
@@ -32,6 +31,10 @@ def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
32
31
from speech_recognition .whisper .whisper .whisper .transcribe import transcribe
33
32
model = load_model (model_name )
34
33
model .eval ()
34
+ if use_torch_fp16 :
35
+ model = model .half ()
36
+ model ._encoder .half ()
37
+ model ._decoder .half ()
35
38
36
39
def single_pass_pytorch (_runner , _librispeech ):
37
40
array = _librispeech .get_input_array ()
@@ -40,15 +43,22 @@ def single_pass_pytorch(_runner, _librispeech):
40
43
_runner .run (batch_size * array .shape [0 ], audio )["text" ].lstrip ().replace ("." , "" ).upper ()
41
44
)
42
45
46
+ decode_options = {"fp16" : use_torch_fp16 }
47
+
43
48
def transcribe_wrapper (audio ):
44
- return transcribe (model , audio , no_speech_threshold = 1.0 , verbose = None )
49
+ return transcribe (model , audio , no_speech_threshold = 1.0 , verbose = None , ** decode_options )
45
50
46
51
runner = PyTorchRunnerV2 (transcribe_wrapper , throughput_only = True )
47
52
librispeech = LibriSpeech ()
48
53
print_warning_message ("Sampling rate Whisper operates at is 16,000 Hz, therefore throughput values below can be "
49
54
"divided by 16,000 to derive 'seconds of processed audio per second'" )
50
55
return run_model (single_pass_pytorch , runner , librispeech , batch_size , num_runs , timeout )
51
56
57
def run_pytorch_fp32(model_name, num_runs, timeout):
    """Benchmark the Whisper model at full (fp32) precision.

    Thin entry-point wrapper: delegates to run_pytorch with the
    fp16 switch disabled, keeping the public fp32 API unchanged.
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
59
+
60
def run_pytorch_fp16(model_name, num_runs, timeout):
    """Benchmark the Whisper model at half (fp16) precision.

    Thin entry-point wrapper: delegates to run_pytorch with the
    fp16 switch enabled.
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
52
62
53
63
def run_pytorch_cuda (model_name , num_runs , timeout , ** kwargs ):
54
64
import os
0 commit comments