Skip to content

Commit

Permalink
Merge pull request #2 from dusty-nv/main
Browse files (browse the repository at this point in the history)
Add option for verbose TRT logging
  • Loading branch information
dusty-nv authored May 28, 2024
2 parents 25450d1 + 1892773 commit 9a93ee6
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions whisper_trt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

import torch.nn as nn
import torch2trt
import tensorrt

from dataclasses import asdict
from .cache import get_cache_dir, make_cache_dir
from .__version__ import __version__
Expand Down Expand Up @@ -150,7 +152,7 @@ def forward(self, mel: Tensor, tokens: Tensor):
return self.decoder(tokens, self.encoder(mel))

@torch.no_grad()
def transcribe(self, audio: str | np.ndarray):
def transcribe(self, audio : str | np.ndarray):

if isinstance(audio, str):
audio = whisper.audio.load_audio(audio)
Expand Down Expand Up @@ -184,6 +186,7 @@ class WhisperTRTBuilder:
model: str
fp16_mode: bool = True
max_workspace_size: int = 1 << 30
verbose: bool = False

@classmethod
@torch.no_grad()
Expand Down Expand Up @@ -222,7 +225,8 @@ def build_text_decoder_engine(cls) -> torch2trt.TRTModule:
input_names=["x", "xa", "mask"],
output_names=["output"],
max_workspace_size=cls.max_workspace_size,
fp16_mode=cls.fp16_mode
fp16_mode=cls.fp16_mode,
log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR,
)

return engine
Expand Down Expand Up @@ -266,7 +270,8 @@ def build_audio_encoder_engine(cls) -> torch2trt.TRTModule:
input_names=["x", "positional_embedding"],
output_names=["output"],
max_workspace_size=cls.max_workspace_size,
fp16_mode=cls.fp16_mode
fp16_mode=cls.fp16_mode,
log_level=tensorrt.Logger.VERBOSE if cls.verbose else tensorrt.Logger.ERROR,
)

return engine
Expand Down Expand Up @@ -298,8 +303,9 @@ def get_audio_encoder_extra_state(cls):

@classmethod
@torch.no_grad()
def build(cls, output_path: str):

def build(cls, output_path: str, verbose: bool = False):
cls.verbose = verbose

checkpoint = {
"whisper_trt_version": __version__,
"dims": asdict(load_model(cls.model).dims),
Expand Down Expand Up @@ -398,7 +404,7 @@ class SmallEnBuilder(EnBuilder):
"small.en": SmallEnBuilder
}

def load_trt_model(name: str, path: str | None = None, build: bool = True):
def load_trt_model(name: str, path: str | None = None, build: bool = True, verbose: bool = False):

if name not in MODEL_BUILDERS:
raise RuntimeError(f"Model '{name}' is not supported by WhisperTRT.")
Expand All @@ -413,6 +419,6 @@ def load_trt_model(name: str, path: str | None = None, build: bool = True):
if not build:
raise RuntimeError(f"No model found at {path}. Please call load_trt_model with build=True.")
else:
builder.build(path)
builder.build(path, verbose=verbose)

return builder.load(path)

0 comments on commit 9a93ee6

Please sign in to comment.