Commit

support bitransformer decoder
Zth9730 committed Sep 20, 2022
1 parent 455379b commit 0a95689
Showing 4 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -40,7 +40,7 @@ def __init__(self, config, args):
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
-
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -90,7 +90,7 @@ def run(self):
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
             simulate_streaming=decode_config.simulate_streaming,
-            reverse_weight=self.config.model_conf.reverse_weight)
+            reverse_weight=self.reverse_weight)
         rsl = result_transcripts[0][0]
         utt = Path(self.audio_file).name
         logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
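Note: the hunk above reads `reverse_weight` once in `__init__` with a safe default, so configs written for a unidirectional decoder (no `reverse_weight` key) keep working. A minimal sketch of the pattern, with a hypothetical `model_conf`:

```python
# Sketch of the getattr-with-default pattern; model_conf contents are hypothetical.
from types import SimpleNamespace

model_conf = SimpleNamespace(ctc_weight=0.3)  # no 'reverse_weight' attribute

# Falls back to 0.0 instead of raising AttributeError, so older
# unidirectional-decoder configs need no changes.
reverse_weight = getattr(model_conf, 'reverse_weight', 0.0)
assert reverse_weight == 0.0
```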
2 changes: 1 addition & 1 deletion paddlespeech/s2t/exps/u2/model.py
@@ -316,7 +316,7 @@ def __init__(self, config, args):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
-        self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
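Note: the old line had two problems that this hunk fixes: it looked up `reverse_weight` on `config` rather than `config.model_conf`, and its fallback was the string `'0.0'` rather than the float `0.0`. A string default only fails later, inside the score arithmetic; an illustrative (non-repo) reproduction:

```python
# Why a string default is a latent bug: it survives until arithmetic is attempted.
left_score, right_score = -3.2, -3.5  # hypothetical log-probabilities

reverse_weight = '0.0'  # old default (string)
try:
    combined = left_score * (1 - reverse_weight) + right_score * reverse_weight
except TypeError as err:
    print(f"string default fails: {err}")

reverse_weight = 0.0    # new default (float)
combined = left_score * (1 - reverse_weight) + right_score * reverse_weight
print(combined)  # -3.2
```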
12 changes: 6 additions & 6 deletions paddlespeech/s2t/models/u2/u2.py
@@ -689,24 +689,24 @@ def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
"""
return self.ctc.log_softmax(xs)

@jit.to_static
# @jit.to_static
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
torch.Tensor: decoder output
paddle.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False

@jit.to_static
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor,
reverse_weight: float=0, ) -> paddle.Tensor:
reverse_weight: float=0.0, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
@@ -783,15 +783,15 @@ def paddle_gather(x, dim, index):
         # >>> tensor([[3, 2, 1],
         # >>>         [4, 8, 9],
         # >>>         [2, eos, eos]])
-        r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1)
+        r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
         # >>> r_hyps
         # >>> tensor([[sos, 3, 2, 1],
         # >>>         [sos, 4, 8, 9],
         # >>>         [sos, 2, eos, eos]])
         decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                                      hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
         return decoder_out, r_decoder_out

     @paddle.no_grad()
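Note: besides commenting out `@jit.to_static` (presumably to avoid static-graph conversion issues for these export-facing methods), the hunks above replace leftover `torch` calls with their Paddle equivalents: `torch.concat` becomes `paddle.concat`, and `log_softmax`'s keyword changes from `dim` to `axis`. `forward_attention_decoder` now returns log-probabilities from both the left-to-right and right-to-left decoders; a hedged sketch of how a rescoring caller might blend them with `reverse_weight` (names and the scoring loop are illustrative, not this file's code):

```python
def rescore_hypothesis(decoder_out, r_decoder_out, hyp, reverse_weight=0.3):
    """Blend L2R and R2L attention-decoder scores for one hypothesis.

    decoder_out / r_decoder_out: per-step log-probs, indexable as [step][token].
    hyp: list of token ids without sos/eos. Illustrative sketch only.
    """
    score = sum(float(decoder_out[i][tok]) for i, tok in enumerate(hyp))
    # The right-to-left decoder scored the reversed token sequence.
    r_score = sum(float(r_decoder_out[i][tok]) for i, tok in enumerate(reversed(hyp)))
    # reverse_weight=0.0 reduces to pure left-to-right rescoring.
    return (1 - reverse_weight) * score + reverse_weight * r_score
```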
5 changes: 2 additions & 3 deletions paddlespeech/s2t/modules/decoder.py
@@ -363,9 +363,8 @@ def forward_one_step(
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoded memory mask, (batch, 1, maxlen_in)
             tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                dtype=torch.uint8 in PyTorch 1.2-
-                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                dtype=paddle.bool
             cache: cached output list of (batch, max_time_out-1, size)
         Returns:
             y, cache: NN output value and cache per `self.decoders`.
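Note: the docstring now describes what the Paddle code actually expects: a square causal mask of shape (batch, maxlen_out, maxlen_out) with `dtype=paddle.bool`, instead of PyTorch's version-dependent `uint8`/`bool` masks. A minimal sketch of constructing such a mask (standard lower-triangular construction; assumed, not quoted from this file):

```python
import paddle

def subsequent_mask(size: int) -> paddle.Tensor:
    """Causal mask of shape (size, size), dtype=paddle.bool.

    Entry [i][j] is True iff position i may attend to position j (j <= i).
    """
    return paddle.tril(paddle.ones([size, size], dtype=paddle.bool))

maxlen_out = 4
tgt_mask = subsequent_mask(maxlen_out).unsqueeze(0)  # (1, 4, 4), broadcasts over batch
print(tgt_mask.dtype)  # paddle.bool
```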
