Skip to content

Commit

Permalink
fix type hints for Python 3.8 and add verbose logging
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed May 28, 2024
1 parent 25450d1 commit 97b2438
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions whisper_trt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

import torch.nn as nn
import torch2trt
import tensorrt

from dataclasses import asdict
from .cache import get_cache_dir, make_cache_dir
from .__version__ import __version__
Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(self,
dims: ModelDimensions,
encoder: AudioEncoderTRT,
decoder: TextDecoderTRT,
tokenizer: Tokenizer | None = None
tokenizer: Tokenizer = None
):
super().__init__()
self.dims = dims
Expand All @@ -150,7 +152,7 @@ def forward(self, mel: Tensor, tokens: Tensor):
return self.decoder(tokens, self.encoder(mel))

@torch.no_grad()
def transcribe(self, audio: str | np.ndarray):
def transcribe(self, audio): #: str | np.ndarray):

if isinstance(audio, str):
audio = whisper.audio.load_audio(audio)
Expand Down Expand Up @@ -184,6 +186,7 @@ class WhisperTRTBuilder:
model: str
fp16_mode: bool = True
max_workspace_size: int = 1 << 30
verbose: bool = False

@classmethod
@torch.no_grad()
Expand Down Expand Up @@ -222,7 +225,8 @@ def build_text_decoder_engine(cls) -> torch2trt.TRTModule:
input_names=["x", "xa", "mask"],
output_names=["output"],
max_workspace_size=cls.max_workspace_size,
fp16_mode=cls.fp16_mode
fp16_mode=cls.fp16_mode,
log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR,
)

return engine
Expand Down Expand Up @@ -266,7 +270,8 @@ def build_audio_encoder_engine(cls) -> torch2trt.TRTModule:
input_names=["x", "positional_embedding"],
output_names=["output"],
max_workspace_size=cls.max_workspace_size,
fp16_mode=cls.fp16_mode
fp16_mode=cls.fp16_mode,
log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR,
)

return engine
Expand Down Expand Up @@ -298,8 +303,9 @@ def get_audio_encoder_extra_state(cls):

@classmethod
@torch.no_grad()
def build(cls, output_path: str):

def build(cls, output_path: str, verbose: bool = False):
cls.verbose = verbose

checkpoint = {
"whisper_trt_version": __version__,
"dims": asdict(load_model(cls.model).dims),
Expand Down Expand Up @@ -398,7 +404,7 @@ class SmallEnBuilder(EnBuilder):
"small.en": SmallEnBuilder
}

def load_trt_model(name: str, path: str | None = None, build: bool = True):
def load_trt_model(name: str, path: str = None, build: bool = True, verbose: bool = False):

if name not in MODEL_BUILDERS:
raise RuntimeError(f"Model '{name}' is not supported by WhisperTRT.")
Expand All @@ -413,6 +419,6 @@ def load_trt_model(name: str, path: str | None = None, build: bool = True):
if not build:
raise RuntimeError(f"No model found at {path}. Please call load_trt_model with build=True.")
else:
builder.build(path)
builder.build(path, verbose=verbose)

return builder.load(path)

0 comments on commit 97b2438

Please sign in to comment.