Skip to content

Commit a2d3eec

Browse files
author
GrzegorzKarchNV
authored
Merge pull request NVIDIA#411 from rajeevsrao/master
Init CUDA state before loading TRT engines in Taco2 sample
2 parents 8674bb6 + 4f42950 commit a2d3eec

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

PyTorch/SpeechSynthesis/Tacotron2/trt/inference_trt.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,15 @@ def main():
292292
parser = parse_args(parser)
293293
args, _ = parser.parse_known_args()
294294

295+
# initialize CUDA state
296+
torch.cuda.init()
297+
295298
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
296299
encoder = load_engine(args.encoder, TRT_LOGGER)
297300
decoder_iter = load_engine(args.decoder, TRT_LOGGER)
298301
postnet = load_engine(args.postnet, TRT_LOGGER)
299302
waveglow = load_engine(args.waveglow, TRT_LOGGER)
300303

301-
302-
303304
if args.waveglow_ckpt != "":
304305
# setup denoiser using WaveGlow PyTorch checkpoint
305306
waveglow_ckpt = load_and_setup_model('WaveGlow', parser, args.waveglow_ckpt,
@@ -310,9 +311,6 @@ def main():
310311
del waveglow_ckpt
311312
torch.cuda.empty_cache()
312313

313-
314-
# initialize CUDA state
315-
torch.cuda.init()
316314
# create TRT contexts for each engine
317315
encoder_context = encoder.create_execution_context()
318316
decoder_context = decoder_iter.create_execution_context()

0 commit comments

Comments
 (0)