diff --git a/whisper_trt/model.py b/whisper_trt/model.py index 9a63803..ceb5176 100644 --- a/whisper_trt/model.py +++ b/whisper_trt/model.py @@ -34,6 +34,8 @@ import torch.nn as nn import torch2trt +import tensorrt + from dataclasses import asdict from .cache import get_cache_dir, make_cache_dir from .__version__ import __version__ @@ -150,7 +152,7 @@ def forward(self, mel: Tensor, tokens: Tensor): return self.decoder(tokens, self.encoder(mel)) @torch.no_grad() - def transcribe(self, audio: str | np.ndarray): + def transcribe(self, audio : str | np.ndarray): if isinstance(audio, str): audio = whisper.audio.load_audio(audio) @@ -184,6 +186,7 @@ class WhisperTRTBuilder: model: str fp16_mode: bool = True max_workspace_size: int = 1 << 30 + verbose: bool = False @classmethod @torch.no_grad() @@ -222,7 +225,8 @@ def build_text_decoder_engine(cls) -> torch2trt.TRTModule: input_names=["x", "xa", "mask"], output_names=["output"], max_workspace_size=cls.max_workspace_size, - fp16_mode=cls.fp16_mode + fp16_mode=cls.fp16_mode, + log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR, ) return engine @@ -266,7 +270,8 @@ def build_audio_encoder_engine(cls) -> torch2trt.TRTModule: input_names=["x", "positional_embedding"], output_names=["output"], max_workspace_size=cls.max_workspace_size, - fp16_mode=cls.fp16_mode + fp16_mode=cls.fp16_mode, + log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR, ) return engine @@ -298,8 +303,9 @@ def get_audio_encoder_extra_state(cls): @classmethod @torch.no_grad() - def build(cls, output_path: str): - + def build(cls, output_path: str, verbose: bool = False): + cls.verbose = verbose + checkpoint = { "whisper_trt_version": __version__, "dims": asdict(load_model(cls.model).dims), @@ -398,7 +404,7 @@ class SmallEnBuilder(EnBuilder): "small.en": SmallEnBuilder } -def load_trt_model(name: str, path: str | None = None, build: bool = True): +def load_trt_model(name: str, path: str | None = None, build: bool = True, verbose: bool = False): if name not in MODEL_BUILDERS: raise RuntimeError(f"Model '{name}' is not supported by WhisperTRT.") @@ -413,6 +419,6 @@ def load_trt_model(name: str, path: str | None = None, build: bool = True): if not build: raise RuntimeError(f"No model found at {path}. Please call load_trt_model with build=True.") else: - builder.build(path) + builder.build(path, verbose=verbose) return builder.load(path)