[s2t] support bitransformer decoder #2415

Merged
merged 6 commits into from
Sep 22, 2022

Changes from 1 commit
add bitransformer decoder, test=asr
Zth9730 committed Sep 20, 2022
commit 1a56a6e42bccedee0285d8a22205d802878bab92
41 changes: 32 additions & 9 deletions paddlespeech/audio/utils/tensor_utils.py
@@ -31,7 +31,6 @@ def has_tensor(val):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
@@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): use the commented code below,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
# _sos = paddle.to_tensor(
# [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# _eos = paddle.to_tensor(
# [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
# ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
# ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])

B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
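The rest of the function is folded out of this hunk. For context, a minimal runnable sketch (not part of the diff) of what the updated commented-out per-sample variant computes, matching the `add_sos_eos` docstring example; `pad_sequence` is the helper defined earlier in this same file, and boolean-mask indexing is assumed to behave as in the comment above:

```python
# Sketch of the per-sample add_sos_eos variant kept in the comment above.
import paddle

from paddlespeech.audio.utils.tensor_utils import pad_sequence

sos, eos, ignore_id = 10, 11, -1
ys_pad = paddle.to_tensor([[1, 2, 3, 4, 5],
                           [4, 5, 6, -1, -1],
                           [7, 8, 9, -1, -1]])
_sos = paddle.to_tensor([sos], dtype=ys_pad.dtype)
_eos = paddle.to_tensor([eos], dtype=ys_pad.dtype)
ys = [y[y != ignore_id] for y in ys_pad]                 # strip padding
ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]   # <sos> + tokens
ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]  # tokens + <eos>
print(pad_sequence(ys_in, True, eos))         # decoder inputs, padded with eos
print(pad_sequence(ys_out, True, ignore_id))  # targets, padded with ignore_id
# ys_out rows: [1,2,3,4,5,11], [4,5,6,11,-1,-1], [7,8,9,11,-1,-1]
```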
@@ -190,3 +190,26 @@ def th_accuracy(pad_outputs: paddle.Tensor,
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)


def reverse_pad_list(ys_pad: paddle.Tensor,
ys_lens: paddle.Tensor,
pad_value: float=-1.0) -> paddle.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lengths of the token sequences (B).
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> reverse_pad_list(x, paddle.to_tensor([4, 3, 2]), 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True, pad_value)
return r_ys_pad
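A quick usage sketch matching the docstring example above; the lengths give the count of real tokens per row, so only those are flipped:

```python
# Usage sketch for reverse_pad_list on the docstring example.
import paddle

from paddlespeech.audio.utils.tensor_utils import reverse_pad_list

x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
lens = paddle.to_tensor([4, 3, 2])   # real-token count per row
print(reverse_pad_list(x, lens, 0.0))
# [[4, 3, 2, 1],
#  [7, 6, 5, 0],
#  [9, 8, 0, 0]]
```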
3 changes: 2 additions & 1 deletion paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -89,7 +89,8 @@ def run(self):
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.config.model_conf.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
9 changes: 6 additions & 3 deletions paddlespeech/s2t/exps/u2/model.py
@@ -250,10 +250,12 @@ def setup_model(self):
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model_conf.output_dim = 5538

model = U2Model.from_config(model_conf)

# params = model.state_dict()
# paddle.save(params, 'for_torch/test.pdparams')
# exit()
if self.parallel:
model = paddle.DataParallel(model)

@@ -350,7 +352,8 @@ def compute_metrics(self,
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.config.model_conf.reverse_weight)
decode_time = time.time() - start_time

for utt, target, result, rec_tids in zip(
152 changes: 137 additions & 15 deletions paddlespeech/s2t/models/u2/u2.py
@@ -31,13 +31,15 @@

from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@@ -69,6 +71,7 @@ def __init__(self,
ctc: CTCDecoderBase,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
reverse_weight: float=0.0,
lsm_weight: float=0.0,
length_normalized_loss: bool=False,
**kwargs):
@@ -82,6 +85,7 @@ def __init__(self,
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight

self.encoder = encoder
self.decoder = decoder
@@ -171,12 +175,21 @@ def _calc_att_loss(
self.ignore_id)
ys_in_lens = ys_pad_lens + 1

r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
self.ignore_id)
# 1. Forward decoder
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
self.reverse_weight)

# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
r_loss_att = paddle.to_tensor(0.0)
if self.reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - self.reverse_weight
) + r_loss_att * self.reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
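The hunk above blends the two decoding directions into one attention loss. A standalone sketch of that combination, assuming `criterion_att` is the existing LabelSmoothingLoss-style callable that scores (B, L, V) outputs against (B, L) targets:

```python
# Standalone sketch of the bidirectional attention loss above (not the diff itself).
import paddle

def calc_att_loss_sketch(criterion_att, decoder_out, ys_out_pad,
                         r_decoder_out, r_ys_out_pad, reverse_weight):
    loss_att = criterion_att(decoder_out, ys_out_pad)
    r_loss_att = paddle.to_tensor(0.0)
    if reverse_weight > 0.0:
        # only a BiTransformerDecoder produces a meaningful r_decoder_out
        r_loss_att = criterion_att(r_decoder_out, r_ys_out_pad)
    # reverse_weight == 0.0 recovers the original unidirectional loss
    return loss_att * (1.0 - reverse_weight) + r_loss_att * reverse_weight
```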
@@ -359,6 +372,7 @@ def ctc_greedy_search(
# Let's assume B = batch_size
# encoder_out: (B, maxlen, encoder_dim)
# encoder_mask: (B, 1, Tmax)

encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
@@ -500,7 +514,8 @@ def attention_rescoring(
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
ctc_weight: float=0.0,
simulate_streaming: bool=False, ) -> List[int]:
simulate_streaming: bool=False,
reverse_weight: float=0.0, ) -> List[int]:
""" Apply attention rescoring decoding, CTC prefix beam search
is applied first to get nbest, then we resoring the nbest on
attention decoder with corresponding encoder out
@@ -520,6 +535,9 @@
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
if reverse_weight > 0.0:
# decoder should be a bitransformer decoder if reverse_weight > 0.0
assert hasattr(self.decoder, 'right_decoder')
device = speech.place
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
@@ -541,6 +559,7 @@
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
ori_hyps_pad = hyps_pad
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
@@ -550,13 +569,24 @@
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)

# used for right to left decoder
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
self.ignore_id)
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
self.ignore_id)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()

# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy()

# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
Expand All @@ -567,6 +597,12 @@ def attention_rescoring(
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
if reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
Expand Down Expand Up @@ -653,12 +689,24 @@ def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
"""
return self.ctc.log_softmax(xs)

@jit.to_static
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
bool: True if the decoder is a bidirectional (left-to-right and right-to-left) decoder.
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False

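A deployment-side caller could use this flag to decide whether to request right-to-left scores at all; a hypothetical check (names `model`, `hyps`, `hyps_lens`, `encoder_out` are illustrative):

```python
# Hypothetical export-side usage of is_bidirectional_decoder().
reverse_weight = 0.3 if model.is_bidirectional_decoder() else 0.0
decoder_out, r_decoder_out = model.forward_attention_decoder(
    hyps, hyps_lens, encoder_out, reverse_weight)
```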
@jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
encoder_out: paddle.Tensor,
reverse_weight: float=0, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
@@ -676,11 +724,75 @@ def forward_attention_decoder(
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)

# input for right to left decoder
# hyps_lens counts the <sos> token, so subtract one.
r_hyps_lens = hyps_lens - 1
# hyps includes the <sos> token; strip it to recover the original hyps.
r_hyps = hyps[:, 1:]
# (num_hyps, max_hyps_len, vocab_size)

# Equal to:
[Collaborator review comment: consider extracting this block of code into a separate function.]

# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = paddle.max(r_hyps_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)

index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask

# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index),
[1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out

r_hyps = paddle_gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, self.eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens, r_hyps, reverse_weight)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
return decoder_out
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
return decoder_out, r_decoder_out
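The arange/mask/gather sequence above is a jit-traceable stand-in for `reverse_pad_list` plus `add_sos_eos`. A minimal standalone sketch of the same row reversal, using `paddle.take_along_axis` in place of the hand-rolled `paddle_gather`; the token and eos ids are illustrative:

```python
# Standalone sketch of the loop-free row reversal used above.
import paddle

r_hyps = paddle.to_tensor([[1, 2, 3], [8, 4, 9], [2, 7, 7]])
r_hyps_lens = paddle.to_tensor([3, 3, 1])   # row 2 has one real token
eos = 5                                     # hypothetical id

max_len = paddle.max(r_hyps_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range               # True on real tokens
index = (seq_len_expand - 1) - index_range            # mirrored positions
index = index * seq_mask.astype(index.dtype)          # clamp pads to 0
flipped = paddle.take_along_axis(r_hyps, index, axis=1)
flipped = paddle.where(seq_mask, flipped, paddle.full_like(flipped, eos))
print(flipped)  # [[3, 2, 1], [9, 4, 8], [2, 5, 5]]
```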

@paddle.no_grad()
def decode(self,
@@ -692,7 +804,8 @@ def decode(self,
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False):
simulate_streaming: bool=False,
reverse_weight: float=0.0):
"""u2 decoding.

Args:
@@ -801,7 +914,6 @@ def __init__(self, configs: dict):
with DefaultInitializerContext(init_type):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
configs)

super().__init__(
vocab_size=vocab_size,
encoder=encoder,
@@ -851,10 +963,20 @@ def _init_from_config(cls, configs: dict):
raise ValueError(f"not support encoder type:{encoder_type}")

# decoder
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])

decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
if decoder_type == 'transformer':
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
elif decoder_type == 'bitransformer':
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
assert configs['decoder_conf']['r_num_blocks'] > 0
decoder = BiTransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
else:
raise ValueError(f"not support decoder type:{decoder_type}")
# ctc decoder and ctc loss
model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
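For reference, a hedged sketch of a `configs` dict exercising the new branch above. Only the `decoder`, `decoder_conf.r_num_blocks`, and `model_conf.reverse_weight` keys are taken from this diff; the remaining keys and values are illustrative placeholders:

```python
# Illustrative configs dict for U2Model.from_config selecting the new decoder.
configs = {
    'decoder': 'bitransformer',
    'decoder_conf': {
        'attention_heads': 4,
        'linear_units': 2048,
        'num_blocks': 3,        # left-to-right decoder layers
        'r_num_blocks': 3,      # right-to-left decoder layers, must be > 0
        'dropout_rate': 0.1,
    },
    'model_conf': {
        'ctc_weight': 0.3,
        'reverse_weight': 0.3,  # must lie strictly between 0.0 and 1.0
    },
}
```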