diff --git a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py index f794f92ff9..87dc93ffd3 100644 --- a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py +++ b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py @@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens): expected_tokens = ["|", "f", "|", "o", "a"] self.assertEqual(tokens, expected_tokens) + + def test_lm_lifecycle(self): + """Passing lm without assiging it to a vaiable won't cause runtime error + + https://github.com/pytorch/audio/issues/3218 + """ + from torchaudio.models.decoder import ctc_decoder + + from .ctc_decoder_utils import CustomZeroLM + + decoder = ctc_decoder( + lexicon=get_asset_path("decoder/lexicon.txt"), + tokens=get_asset_path("decoder/tokens.txt"), + lm=CustomZeroLM(), + ) + decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32)) diff --git a/torchaudio/models/decoder/_ctc_decoder.py b/torchaudio/models/decoder/_ctc_decoder.py index d9fa5165d8..33daa09ec9 100644 --- a/torchaudio/models/decoder/_ctc_decoder.py +++ b/torchaudio/models/decoder/_ctc_decoder.py @@ -269,6 +269,12 @@ def __init__( ) else: self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions) + # https://github.com/pytorch/audio/issues/3218 + # If lm is passed like rvalue reference, the lm object gets garbage collected, + # and later call to the lm fails. + # This ensures that lm object is not deleted as long as the decoder is alive. + # https://github.com/pybind/pybind11/discussions/4013 + self.lm = lm def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: idxs = (g[0] for g in it.groupby(idxs))