From 97b2438148bb5933d5a5d6327ddf0cb83e1bf553 Mon Sep 17 00:00:00 2001 From: Dustin Franklin Date: Tue, 28 May 2024 09:30:01 -0400 Subject: [PATCH 1/2] fix type hints for Python 3.8 and add verbose logging --- whisper_trt/model.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/whisper_trt/model.py b/whisper_trt/model.py index 9a63803..99ca29d 100644 --- a/whisper_trt/model.py +++ b/whisper_trt/model.py @@ -34,6 +34,8 @@ import torch.nn as nn import torch2trt +import tensorrt + from dataclasses import asdict from .cache import get_cache_dir, make_cache_dir from .__version__ import __version__ @@ -132,7 +134,7 @@ def __init__(self, dims: ModelDimensions, encoder: AudioEncoderTRT, decoder: TextDecoderTRT, - tokenizer: Tokenizer | None = None + tokenizer: Tokenizer = None ): super().__init__() self.dims = dims @@ -150,7 +152,7 @@ def forward(self, mel: Tensor, tokens: Tensor): return self.decoder(tokens, self.encoder(mel)) @torch.no_grad() - def transcribe(self, audio: str | np.ndarray): + def transcribe(self, audio): #: str | np.ndarray): if isinstance(audio, str): audio = whisper.audio.load_audio(audio) @@ -184,6 +186,7 @@ class WhisperTRTBuilder: model: str fp16_mode: bool = True max_workspace_size: int = 1 << 30 + verbose: bool = False @classmethod @torch.no_grad() @@ -222,7 +225,8 @@ def build_text_decoder_engine(cls) -> torch2trt.TRTModule: input_names=["x", "xa", "mask"], output_names=["output"], max_workspace_size=cls.max_workspace_size, - fp16_mode=cls.fp16_mode + fp16_mode=cls.fp16_mode, + log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR, ) return engine @@ -266,7 +270,8 @@ def build_audio_encoder_engine(cls) -> torch2trt.TRTModule: input_names=["x", "positional_embedding"], output_names=["output"], max_workspace_size=cls.max_workspace_size, - fp16_mode=cls.fp16_mode + fp16_mode=cls.fp16_mode, + log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR, ) return engine @@ -298,8 +303,9 @@ def get_audio_encoder_extra_state(cls): @classmethod @torch.no_grad() - def build(cls, output_path: str): - + def build(cls, output_path: str, verbose: bool = False): + cls.verbose = verbose + checkpoint = { "whisper_trt_version": __version__, "dims": asdict(load_model(cls.model).dims), @@ -398,7 +404,7 @@ class SmallEnBuilder(EnBuilder): "small.en": SmallEnBuilder } -def load_trt_model(name: str, path: str | None = None, build: bool = True): +def load_trt_model(name: str, path: str = None, build: bool = True, verbose: bool = False): if name not in MODEL_BUILDERS: raise RuntimeError(f"Model '{name}' is not supported by WhisperTRT.") @@ -413,6 +419,6 @@ def load_trt_model(name: str, path: str | None = None, build: bool = True): if not build: raise RuntimeError(f"No model found at {path}. Please call load_trt_model with build=True.") else: - builder.build(path) + builder.build(path, verbose=verbose) return builder.load(path) From 1892773e77852f3c0294675c8df42cbfc965cb94 Mon Sep 17 00:00:00 2001 From: Dustin Franklin Date: Tue, 28 May 2024 10:38:36 -0400 Subject: [PATCH 2/2] restore type hints --- whisper_trt/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper_trt/model.py b/whisper_trt/model.py index 99ca29d..ceb5176 100644 --- a/whisper_trt/model.py +++ b/whisper_trt/model.py @@ -134,7 +134,7 @@ def __init__(self, dims: ModelDimensions, encoder: AudioEncoderTRT, decoder: TextDecoderTRT, - tokenizer: Tokenizer = None + tokenizer: Tokenizer | None = None ): super().__init__() self.dims = dims @@ -152,7 +152,7 @@ def forward(self, mel: Tensor, tokens: Tensor): return self.decoder(tokens, self.encoder(mel)) @torch.no_grad() - def transcribe(self, audio): #: str | np.ndarray): + def transcribe(self, audio : str | np.ndarray): if isinstance(audio, str): audio = whisper.audio.load_audio(audio) @@ -404,7 +404,7 @@ class SmallEnBuilder(EnBuilder): "small.en": SmallEnBuilder } -def load_trt_model(name: str, path: str = None, build: bool = True, verbose: bool = False): +def load_trt_model(name: str, path: str | None = None, build: bool = True, verbose: bool = False): if name not in MODEL_BUILDERS: raise RuntimeError(f"Model '{name}' is not supported by WhisperTRT.")