diff --git a/examples/wenetspeech/asr1/README.md b/examples/wenetspeech/asr1/README.md index c08b94e29b9..5a516f8ea27 100644 --- a/examples/wenetspeech/asr1/README.md +++ b/examples/wenetspeech/asr1/README.md @@ -12,3 +12,36 @@ show model.tar.gz ``` tar tf model.tar.gz ``` + +other way is: + +```bash +tar cvzf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz model.yaml conf/tuning/ conf/chunk_conformer.yaml conf/preprocess.yaml data/mean_std.json exp/chunk_conformer/checkpoints/ +``` + +## Export Static Model + +>> Need Paddle >= 2.4 + +>> `data/test_meeting/data.list` +>> {"input": [{"name": "input1", "shape": [3.2230625, 80], "feat": "/home/PaddleSpeech/dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0163.wav", "filetype": "sound"}], "output": [{"name": "target1", "shape": [9, 5538], "text": "\u697c\u5e02\u8c03\u63a7\u5c06\u53bb\u5411\u4f55\u65b9", "token": "\u697c \u5e02 \u8c03 \u63a7 \u5c06 \u53bb \u5411 \u4f55 \u65b9", "tokenid": "1891 1121 3502 1543 1018 477 528 163 1657"}], "utt": "BAC009S0764W0163", "utt2spk": "S0764"} + +>> Test Wav: +>> wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav +### U2 chunk conformer +>> UiDecoder +>> Make sure `reverse_weight` in config is `0.0` +>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz +``` +tar zxvf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz +./local/export.sh conf/chunk_conformer.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.ji +``` + +### U2++ chunk conformer +>> BiDecoder +>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.0.model.tar.gz +>> Make sure `reverse_weight` in config is not `0.0` + +``` +./local/export.sh conf/chunk_conformer_u2pp.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.ji +``` diff --git a/examples/wenetspeech/asr1/conf/chunk_conformer.yaml b/examples/wenetspeech/asr1/conf/chunk_conformer.yaml new file mode 100644 index 00000000000..d2f43d873e0 --- /dev/null +++ b/examples/wenetspeech/asr1/conf/chunk_conformer.yaml @@ -0,0 +1,101 @@ +############################################ +# Network Architecture # +############################################ +cmvn_file: +cmvn_file_type: "json" +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + use_cnn_module: True + cnn_module_kernel: 15 + activation_type: swish + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + causal: true + use_dynamic_chunk: true + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + reverse_weight: 0.0 # unidecoder + length_normalized_loss: false + init_type: 'kaiming_uniform' + +# https://yaml.org/type/float.html +########################################### +# Data # 
+########################################### +train_manifest: data/train_l/data.list +dev_manifest: data/dev/data.list +test_manifest: data/test_meeting/data.list + +########################################### +# Dataloader # +########################################### +use_streaming_data: True +unit_type: 'char' +vocab_filepath: data/lang_char/vocab.txt +preprocess_config: conf/preprocess.yaml +spm_model_prefix: '' +feat_dim: 80 +stride_ms: 10.0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 32 +do_filter: True +maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced +minlen_in: 10 +minlen_out: 0 +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 0 +subsampling_factor: 1 +num_encs: 1 + + +########################################### +# Training # +########################################### +n_epoch: 26 +accum_grad: 32 +global_grad_clip: 5.0 +dist_sampler: True +log_interval: 1 +checkpoint: + kbest_n: 50 + latest_n: 5 +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 1.0e-6 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 5000 + lr_decay: 1.0 diff --git a/examples/wenetspeech/asr1/conf/chunk_conformer_u2pp.yaml b/examples/wenetspeech/asr1/conf/chunk_conformer_u2pp.yaml new file mode 100644 index 00000000000..2bb2006b5b7 --- /dev/null +++ b/examples/wenetspeech/asr1/conf/chunk_conformer_u2pp.yaml @@ -0,0 +1,100 @@ +############################################ +# Network Architecture # +############################################ +cmvn_file: +cmvn_file_type: "json" +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + use_cnn_module: True + cnn_module_kernel: 15 + activation_type: swish + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + causal: true + use_dynamic_chunk: true + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 3 # the number of encoder blocks + r_num_blocks: 3 #only for bitransformer + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + reverse_weight: 0.3 # only for bitransformer decoder + init_type: 'kaiming_uniform' # !Warning: need to convergence + +########################################### +# Data # +########################################### +train_manifest: data/train_l/data.list +dev_manifest: data/dev/data.list +test_manifest: data/test_meeting/data.list + +########################################### +# Dataloader # +########################################### +use_stream_data: True +vocab_filepath: data/lang_char/vocab.txt +unit_type: 
'char' +preprocess_config: conf/preprocess.yaml +spm_model_prefix: '' +feat_dim: 80 +stride_ms: 10.0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 32 +do_filter: True +maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced +minlen_in: 10 +minlen_out: 0 +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 0 +subsampling_factor: 1 +num_encs: 1 + +########################################### +# Training # +########################################### +n_epoch: 150 +accum_grad: 8 +global_grad_clip: 5.0 +dist_sampler: False +optim: adam +optim_conf: + lr: 0.002 + weight_decay: 1.0e-6 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +log_interval: 100 +checkpoint: + kbest_n: 50 + latest_n: 5 diff --git a/examples/wenetspeech/asr1/conf/preprocess.yaml b/examples/wenetspeech/asr1/conf/preprocess.yaml index f7f4c58d522..c7ccc522d52 100644 --- a/examples/wenetspeech/asr1/conf/preprocess.yaml +++ b/examples/wenetspeech/asr1/conf/preprocess.yaml @@ -5,7 +5,7 @@ process: n_mels: 80 n_shift: 160 win_length: 400 - dither: 0.1 + dither: 1.0 - type: cmvn_json cmvn_path: data/mean_std.json # these three processes are a.k.a. SpecAugument diff --git a/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml b/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml new file mode 100644 index 00000000000..6945ed6eb2a --- /dev/null +++ b/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml @@ -0,0 +1,12 @@ +beam_size: 10 +decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' +ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. +reverse_weight: 0.3 # reverse weight for attention rescoring decode mode. +decoding_chunk_size: 16 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. +num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. +simulate_streaming: True # simulate streaming inference. Defaults to False. +decode_batch_size: 128 +error_rate_type: cer diff --git a/examples/wenetspeech/asr1/conf/tuning/decode.yaml b/examples/wenetspeech/asr1/conf/tuning/decode.yaml index 6924bfa637a..4015e9836d9 100644 --- a/examples/wenetspeech/asr1/conf/tuning/decode.yaml +++ b/examples/wenetspeech/asr1/conf/tuning/decode.yaml @@ -1,11 +1,12 @@ -decode_batch_size: 128 -error_rate_type: cer -decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' beam_size: 10 +decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. +reverse_weight: 0.3 # reverse weight for attention rescoring decode mode. decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. -simulate_streaming: False # simulate streaming inference. Defaults to False. 
\ No newline at end of file +simulate_streaming: False # simulate streaming inference. Defaults to False. +decode_batch_size: 128 +error_rate_type: cer diff --git a/examples/wenetspeech/asr1/local/export.sh b/examples/wenetspeech/asr1/local/export.sh index 6b646b46903..1f89afd6b3a 100755 --- a/examples/wenetspeech/asr1/local/export.sh +++ b/examples/wenetspeech/asr1/local/export.sh @@ -12,9 +12,14 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 + +# export can not using StreamdataDataloader, set use_stream_dta False +# u2: reverse_weight should be 0.0 +# u2pp: reverse_weight should be same with config file. e.g. 0.3 python3 -u ${BIN_DIR}/export.py \ --ngpu ${ngpu} \ --config ${config_path} \ +--opts use_stream_data False \ --checkpoint_path ${ckpt_path_prefix} \ --export_path ${jit_model_export_path} diff --git a/examples/wenetspeech/asr1/local/quant.sh b/examples/wenetspeech/asr1/local/quant.sh new file mode 100755 index 00000000000..9dfea9045cb --- /dev/null +++ b/examples/wenetspeech/asr1/local/quant.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +if [ $# != 4 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +decode_config_path=$2 +ckpt_prefix=$3 +audio_file=$4 + +mkdir -p data +wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/ +if [ $? -ne 0 ]; then + exit 1 +fi + +if [ ! -f ${audio_file} ]; then + echo "Plase input the right audio_file path" + exit 1 +fi + + +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi + +# download language model +#bash local/download_lm_ch.sh +#if [ $? -ne 0 ]; then +# exit 1 +#fi + +for type in attention_rescoring; do + echo "decoding ${type}" + batch_size=1 + output_dir=${ckpt_prefix} + mkdir -p ${output_dir} + python3 -u ${BIN_DIR}/quant.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${output_dir}/${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decode.decoding_method ${type} \ + --opts decode.decode_batch_size ${batch_size} \ + --audio_file ${audio_file} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" 
+ exit 1 + fi +done +exit 0 diff --git a/paddlespeech/audio/compliance/kaldi.py b/paddlespeech/audio/compliance/kaldi.py index 538be019619..eb92ec1f2f4 100644 --- a/paddlespeech/audio/compliance/kaldi.py +++ b/paddlespeech/audio/compliance/kaldi.py @@ -74,16 +74,16 @@ def _feature_window_function( window_size: int, blackman_coeff: float, dtype: int, ) -> Tensor: - if window_type == HANNING: + if window_type == "hann": return get_window('hann', window_size, fftbins=False, dtype=dtype) - elif window_type == HAMMING: + elif window_type == "hamming": return get_window('hamming', window_size, fftbins=False, dtype=dtype) - elif window_type == POVEY: + elif window_type == "povey": return get_window( 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85) - elif window_type == RECTANGULAR: + elif window_type == "rect": return paddle.ones([window_size], dtype=dtype) - elif window_type == BLACKMAN: + elif window_type == "blackman": a = 2 * math.pi / (window_size - 1) window_function = paddle.arange(window_size, dtype=dtype) return (blackman_coeff - 0.5 * paddle.cos(a * window_function) + @@ -216,7 +216,7 @@ def spectrogram(waveform: Tensor, sr: int=16000, snip_edges: bool=True, subtract_mean: bool=False, - window_type: str=POVEY) -> Tensor: + window_type: str="povey") -> Tensor: """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's. Args: @@ -236,7 +236,7 @@ def spectrogram(waveform: Tensor, snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. - window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. + window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey". 
Returns: Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames @@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int, ('Bad values in options: vtln-low {} and vtln-high {}, versus ' 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) - bin = paddle.arange(num_bins).unsqueeze(1) + bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1) + # left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) + # center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) + # right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) - center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) - right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) + center_mel = left_mel + mel_freq_delta + right_mel = center_mel + mel_freq_delta if vtln_warp_factor != 1.0: left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, @@ -373,7 +376,8 @@ def _get_mel_banks(num_bins: int, center_freqs = _inverse_mel_scale(center_mel) # (num_bins) # (1, num_fft_bins) - mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) + mel = _mel_scale(fft_bin_width * paddle.arange( + num_fft_bins, dtype=paddle.float32)).unsqueeze(0) # (num_bins, num_fft_bins) up_slope = (mel - left_mel) / (center_mel - left_mel) @@ -418,11 +422,11 @@ def fbank(waveform: Tensor, vtln_high: float=-500.0, vtln_low: float=100.0, vtln_warp: float=1.0, - window_type: str=POVEY) -> Tensor: + window_type: str="povey") -> Tensor: """Compute and return filter banks from a waveform. The output is identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape `(C, T)`. + waveform (Tensor): A waveform tensor with shape `(C, T)`. `C` is in the range [0,1]. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. channel (int, optional): Select the channel of waveform. Defaults to -1. dither (float, optional): Dithering constant . Defaults to 0.0. @@ -448,7 +452,7 @@ def fbank(waveform: Tensor, vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0. vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0. vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0. - window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. + window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey". Returns: Tensor: A filter banks tensor with shape `(m, n_mels)`. @@ -472,7 +476,8 @@ def fbank(waveform: Tensor, # (n_mels, padded_window_size // 2) mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) - mel_energies = mel_energies.astype(dtype) + # mel_energies = mel_energies.astype(dtype) + assert mel_energies.dtype == dtype # (n_mels, padded_window_size // 2 + 1) mel_energies = paddle.nn.functional.pad( @@ -537,7 +542,7 @@ def mfcc(waveform: Tensor, vtln_high: float=-500.0, vtln_low: float=100.0, vtln_warp: float=1.0, - window_type: str=POVEY) -> Tensor: + window_type: str="povey") -> Tensor: """Compute and return mel frequency cepstral coefficients from a waveform. The output is identical to Kaldi's. 
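The `_get_mel_banks` change above computes `center_mel` and `right_mel` incrementally from `left_mel` and pins the `arange` calls to `float32`; the filterbank itself is unchanged. A minimal sketch of the equivalence (the constants below are placeholders, not the values derived from the sample rate and `low_freq`/`high_freq`):

```python
import paddle

# Placeholder values for illustration; _get_mel_banks derives the real ones
# from num_bins, low_freq and high_freq.
num_bins, mel_low_freq, mel_freq_delta = 4, 20.0, 7.5

bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)  # (num_bins, 1)
left_mel = mel_low_freq + bin * mel_freq_delta

# previous formulation
center_old = mel_low_freq + (bin + 1.0) * mel_freq_delta
right_old = mel_low_freq + (bin + 2.0) * mel_freq_delta

# formulation in this patch
center_new = left_mel + mel_freq_delta
right_new = center_new + mel_freq_delta

assert bool(paddle.allclose(center_old, center_new))
assert bool(paddle.allclose(right_old, right_new))
```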
diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index e9008f17422..93883c94dc3 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -152,8 +152,8 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0]) B = ys_pad.shape[0] - _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos - _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos + _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype) + _eos = paddle.full([B, 1], eos, dtype=ys_pad.dtype) ys_in = paddle.cat([_sos, ys_pad], dim=1) mask_pad = (ys_in == ignore_id) ys_in = ys_in.masked_fill(mask_pad, eos) @@ -279,8 +279,8 @@ def paddle_gather(x, dim, index): # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, 2, 2]]) - eos = paddle.full([1], eos, dtype=r_hyps.dtype) - r_hyps = paddle.where(seq_mask, r_hyps, eos) + _eos = paddle.full([1], eos, dtype=r_hyps.dtype) + r_hyps = paddle.where(seq_mask, r_hyps, _eos) # >>> r_hyps # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 5fe2e16b91b..6663bcf87be 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -22,7 +22,6 @@ from paddlespeech.s2t.utils.log import Log -#TODO(Hui Zhang): remove fluid import logger = Log(__name__).getlog() ########### hack logging ############# @@ -167,13 +166,17 @@ def broadcast_shape(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): + # will be nan when value is `inf`. + # mask = mask.astype(xs.dtype) + # return xs * (1.0 - mask) + mask * value + bshape = broadcast_shape(xs.shape, mask.shape) mask.stop_gradient = True - tmp = paddle.ones(shape=[len(bshape)], dtype='int32') - for index in range(len(bshape)): - tmp[index] = bshape[index] - mask = mask.broadcast_to(tmp) - trues = paddle.ones_like(xs) * value + # tmp = paddle.ones(shape=[len(bshape)], dtype='int32') + # for index in range(len(bshape)): + # tmp[index] = bshape[index] + mask = mask.broadcast_to(bshape) + trues = paddle.full_like(xs, fill_value=value) xs = paddle.where(mask, trues, xs) return xs diff --git a/paddlespeech/s2t/exps/u2/bin/quant.py b/paddlespeech/s2t/exps/u2/bin/quant.py new file mode 100644 index 00000000000..225bbf6dbb5 --- /dev/null +++ b/paddlespeech/s2t/exps/u2/bin/quant.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Evaluation for U2 model.""" +import os +import sys +from pathlib import Path + +import paddle +import soundfile +from paddleslim import PTQ +from yacs.config import CfgNode + +from paddlespeech.audio.transform.transformation import Transformation +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.models.u2 import U2Model +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import UpdateConfig +logger = Log(__name__).getlog() + + +class U2Infer(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + + self.preprocess_conf = config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) + self.text_feature = TextFeaturizer( + unit_type=config.unit_type, + vocab=config.vocab_filepath, + spm_model_prefix=config.spm_model_prefix) + + paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + + # model + model_conf = config + with UpdateConfig(model_conf): + model_conf.input_dim = config.feat_dim + model_conf.output_dim = self.text_feature.vocab_size + model = U2Model.from_config(model_conf) + self.model = model + self.model.eval() + self.ptq = PTQ() + self.model = self.ptq.quantize(model) + + # load model + params_path = self.args.checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + def run(self): + check(args.audio_file) + + with paddle.no_grad(): + # read + audio, sample_rate = soundfile.read( + self.audio_file, dtype="int16", always_2d=True) + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + + # fbank + feat = self.preprocessing(audio, **self.preprocess_args) + logger.info(f"feat shape: {feat.shape}") + + ilen = paddle.to_tensor(feat.shape[0]) + xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) + decode_config = self.config.decode + logger.info(f"decode cfg: {decode_config}") + result_transcripts = self.model.decode( + xs, + ilen, + text_feature=self.text_feature, + decoding_method=decode_config.decoding_method, + beam_size=decode_config.beam_size, + ctc_weight=decode_config.ctc_weight, + decoding_chunk_size=decode_config.decoding_chunk_size, + num_decoding_left_chunks=decode_config.num_decoding_left_chunks, + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=decode_config.reverse_weight) + rsl = result_transcripts[0][0] + utt = Path(self.audio_file).name + logger.info(f"hyp: {utt} {rsl}") + # print(self.model) + # print(self.model.forward_encoder_chunk) + + logger.info("-------------start quant ----------------------") + batch_size = 1 + feat_dim = 80 + model_size = 512 + num_left_chunks = -1 + reverse_weight = 0.3 + logger.info( + f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}" + ) + + # ######################## self.model.forward_encoder_chunk ############ + # input_spec = [ + # # (T,), int16 + # paddle.static.InputSpec(shape=[None], dtype='int16'), + # ] + # self.model.forward_feature = paddle.jit.to_static( + # self.model.forward_feature, input_spec=input_spec) + + ######################### self.model.forward_encoder_chunk ############ + input_spec = [ + # xs, (B, T, D) + paddle.static.InputSpec( + shape=[batch_size, 
None, feat_dim], dtype='float32'), + # offset, int, but need be tensor + paddle.static.InputSpec(shape=[1], dtype='int32'), + # required_cache_size, int + num_left_chunks, + # att_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32'), + # cnn_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32') + ] + self.model.forward_encoder_chunk = paddle.jit.to_static( + self.model.forward_encoder_chunk, input_spec=input_spec) + + ######################### self.model.ctc_activation ######################## + input_spec = [ + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32') + ] + self.model.ctc_activation = paddle.jit.to_static( + self.model.ctc_activation, input_spec=input_spec) + + ######################### self.model.forward_attention_decoder ######################## + input_spec = [ + # hyps, (B, U) + paddle.static.InputSpec(shape=[None, None], dtype='int64'), + # hyps_lens, (B,) + paddle.static.InputSpec(shape=[None], dtype='int64'), + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight + ] + self.model.forward_attention_decoder = paddle.jit.to_static( + self.model.forward_attention_decoder, input_spec=input_spec) + ################################################################################ + + # jit save + logger.info(f"export save: {self.args.export_path}") + config = { + 'is_static': True, + 'combine_params': True, + 'skip_forward': True + } + self.ptq.save_quantized_model(self.model, self.args.export_path) + # paddle.jit.save( + # self.model, + # self.args.export_path, + # combine_params=True, + # skip_forward=True) + + +def check(audio_file): + if not os.path.isfile(audio_file): + print("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main(config, args): + U2Infer(config, args).run() + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + "--audio_file", type=str, help="path of the input audio file") + parser.add_argument( + "--export_path", + type=str, + default='export', + help="path of the input audio file") + args = parser.parse_args() + + config = CfgNode(new_allowed=True) + + if args.config: + config.merge_from_file(args.config) + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + main(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py index f14d804f188..b13fd0d3f61 100644 --- a/paddlespeech/s2t/exps/u2/bin/test.py +++ b/paddlespeech/s2t/exps/u2/bin/test.py @@ -20,8 +20,6 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments -# TODO(hui zhang): dynamic load - def main_sp(config, args): exp = Tester(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py 
b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 46925faedd2..2e067ab6b70 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -68,7 +68,6 @@ def run(self): # read audio, sample_rate = soundfile.read( self.audio_file, dtype="int16", always_2d=True) - audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") @@ -77,8 +76,9 @@ def run(self): logger.info(f"feat shape: {feat.shape}") ilen = paddle.to_tensor(feat.shape[0]) - xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) + xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) decode_config = self.config.decode + logger.info(f"decode cfg: {decode_config}") result_transcripts = self.model.decode( xs, ilen, @@ -88,7 +88,8 @@ def run(self): ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming) + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=decode_config.reverse_weight) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name logger.info(f"hyp: {utt} {result_transcripts[0][0]}") diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index a6197d07330..d093821d856 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -350,7 +350,8 @@ def compute_metrics(self, ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming) + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=decode_config.reverse_weight) decode_time = time.time() - start_time for utt, target, result, rec_tids in zip( @@ -462,20 +463,120 @@ def load_inferspec(self): infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.clone(), self.args.checkpoint_path) + batch_size = 1 feat_dim = self.test_loader.feat_dim - input_spec = [ - paddle.static.InputSpec(shape=[1, None, feat_dim], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[1], - dtype='int64'), # audio_length, [B] - ] - return infer_model, input_spec + model_size = self.config.encoder_conf.output_size + num_left_chunks = -1 + logger.info( + f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}" + ) + + return infer_model, (batch_size, feat_dim, model_size, num_left_chunks) @paddle.no_grad() def export(self): infer_model, input_spec = self.load_inferspec() - assert isinstance(input_spec, list), type(input_spec) infer_model.eval() - static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) - logger.info(f"Export code: {static_model.forward.code}") - paddle.jit.save(static_model, self.args.export_path) + paddle.set_device('cpu') + + assert isinstance(input_spec, (list, tuple)), type(input_spec) + batch_size, feat_dim, model_size, num_left_chunks = input_spec + + ######################## infer_model.forward_encoder_chunk ############ + input_spec = [ + # (T,), int16 + paddle.static.InputSpec(shape=[None], dtype='int16'), + ] + infer_model.forward_feature = paddle.jit.to_static( + infer_model.forward_feature, input_spec=input_spec) + + ######################### infer_model.forward_encoder_chunk ############ + input_spec = [ + # xs, (B, T, D) + paddle.static.InputSpec( + shape=[batch_size, None, feat_dim], dtype='float32'), 
+ # offset, int, but need be tensor + paddle.static.InputSpec(shape=[1], dtype='int32'), + # required_cache_size, int + num_left_chunks, + # att_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32'), + # cnn_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32') + ] + infer_model.forward_encoder_chunk = paddle.jit.to_static( + infer_model.forward_encoder_chunk, input_spec=input_spec) + + ######################### infer_model.ctc_activation ######################## + input_spec = [ + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32') + ] + infer_model.ctc_activation = paddle.jit.to_static( + infer_model.ctc_activation, input_spec=input_spec) + + ######################### infer_model.forward_attention_decoder ######################## + reverse_weight = 0.3 + input_spec = [ + # hyps, (B, U) + paddle.static.InputSpec(shape=[None, None], dtype='int64'), + # hyps_lens, (B,) + paddle.static.InputSpec(shape=[None], dtype='int64'), + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight + ] + infer_model.forward_attention_decoder = paddle.jit.to_static( + infer_model.forward_attention_decoder, input_spec=input_spec) + + # jit save + logger.info(f"export save: {self.args.export_path}") + paddle.jit.save( + infer_model, + self.args.export_path, + combine_params=True, + skip_forward=True) + + # test dy2static + def flatten(out): + if isinstance(out, paddle.Tensor): + return [out] + + flatten_out = [] + for var in out: + if isinstance(var, (list, tuple)): + flatten_out.extend(flatten(var)) + else: + flatten_out.append(var) + return flatten_out + + # forward_encoder_chunk dygraph + xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32') + offset = paddle.to_tensor([0], dtype='int32') + required_cache_size = num_left_chunks + att_cache = paddle.zeros([0, 0, 0, 0]) + cnn_cache = paddle.zeros([0, 0, 0, 0]) + xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk( + xs1, offset, required_cache_size, att_cache, cnn_cache) + + # load static model + from paddle.jit.layer import Layer + layer = Layer() + logger.info(f"load export model: {self.args.export_path}") + layer.load(self.args.export_path, paddle.CPUPlace()) + + # forward_encoder_chunk static + xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32') + offset = paddle.to_tensor([0], dtype='int32') + att_cache = paddle.zeros([0, 0, 0, 0]) + cnn_cache = paddle.zeros([0, 0, 0, 0]) + func = getattr(layer, 'forward_encoder_chunk') + xs_s, att_cache_s, cnn_cache_s = func(xs1, offset, att_cache, cnn_cache) + np.testing.assert_allclose(xs_d, xs_s, atol=1e-5) + np.testing.assert_allclose(att_cache_d, att_cache_s, atol=1e-4) + np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4) + # logger.info(f"forward_encoder_chunk output: {xs_s}") diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py index 1d70a310347..c07c95bd57d 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/test.py +++ b/paddlespeech/s2t/exps/u2_st/bin/test.py @@ -20,8 +20,6 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments -# TODO(hui zhang): dynamic load - def main_sp(config, args): exp = Tester(config, args) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 4fe51c151a3..544c1e8367e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ 
b/paddlespeech/s2t/models/u2/u2.py @@ -124,17 +124,15 @@ def forward( encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. Attention-decoder branch loss_att = None if self.ctc_weight != 1.0: start = time.time() loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, - text, text_lengths) + text, text_lengths, + self.reverse_weight) decoder_time = time.time() - start #logger.debug(f"decoder time: {decoder_time}") @@ -155,12 +153,12 @@ def forward( loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att return loss, loss_att, loss_ctc - def _calc_att_loss( - self, - encoder_out: paddle.Tensor, - encoder_mask: paddle.Tensor, - ys_pad: paddle.Tensor, - ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + def _calc_att_loss(self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, + reverse_weight: float) -> Tuple[paddle.Tensor, float]: """Calc attention loss. Args: @@ -168,6 +166,7 @@ def _calc_att_loss( encoder_mask (paddle.Tensor): [B, 1, Tmax] ys_pad (paddle.Tensor): [B, Umax] ys_pad_lens (paddle.Tensor): [B] + reverse_weight (float): reverse decoder weight. Returns: Tuple[paddle.Tensor, float]: attention_loss, accuracy rate @@ -182,15 +181,14 @@ def _calc_att_loss( # 1. Forward decoder decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad, - self.reverse_weight) + reverse_weight) # 2. Compute attention loss loss_att = self.criterion_att(decoder_out, ys_out_pad) r_loss_att = paddle.to_tensor(0.0) - if self.reverse_weight > 0.0: + if reverse_weight > 0.0: r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) - loss_att = loss_att * (1 - self.reverse_weight - ) + r_loss_att * self.reverse_weight + loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight acc_att = th_accuracy( decoder_out.view(-1, self.vocab_size), ys_out_pad, @@ -291,8 +289,7 @@ def recognize( # 2. 
Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - # TODO(Hui Zhang): if end_flag.sum() == running_size: - if end_flag.cast(paddle.int64).sum() == running_size: + if end_flag.sum() == running_size: break # 2.1 Forward decoder step @@ -378,9 +375,7 @@ def ctc_greedy_search( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) maxlen = encoder_out.shape[1] - # (TODO Hui Zhang): bool no support reduce_sum - # encoder_out_lens = encoder_mask.squeeze(1).sum(1) - encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) @@ -514,7 +509,8 @@ def attention_rescoring(self, decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, ctc_weight: float=0.0, - simulate_streaming: bool=False) -> List[int]: + simulate_streaming: bool=False, + reverse_weight: float=0.0) -> List[int]: """ Apply attention rescoring decoding, CTC prefix beam search is applied first to get nbest, then we resoring the nbest on attention decoder with corresponding encoder out @@ -529,12 +525,13 @@ def attention_rescoring(self, 0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion + reverse_weight (float): reverse deocder weight. Returns: List[int]: Attention rescoring result """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 - if self.reverse_weight > 0.0: + if reverse_weight > 0.0: # decoder should be a bitransformer decoder if reverse_weight > 0.0 assert hasattr(self.decoder, 'right_decoder') device = speech.place @@ -558,28 +555,22 @@ def attention_rescoring(self, hyp_content, place=device, dtype=paddle.long) hyp_list.append(hyp_content) hyps_pad = pad_sequence(hyp_list, True, self.ignore_id) - ori_hyps_pad = hyps_pad hyps_lens = paddle.to_tensor( [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining - encoder_out = encoder_out.repeat(beam_size, 1, 1) - encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + logger.debug( + f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}") - r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos, - self.eos) - decoder_out, r_decoder_out, _ = self.decoder( - encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, - self.reverse_weight) # (beam_size, max_hyps_len, vocab_size) # ctc score in ln domain - decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - decoder_out = decoder_out.numpy() + # (beam_size, max_hyps_len, vocab_size) + decoder_out, r_decoder_out = self.forward_attention_decoder( + hyps_pad, hyps_lens, encoder_out, reverse_weight) + decoder_out = decoder_out.numpy() # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # conventional transformer decoder. - r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) r_decoder_out = r_decoder_out.numpy() # Only use decoder score for rescoring @@ -592,46 +583,68 @@ def attention_rescoring(self, score += decoder_out[i][j][w] # last decoder output token is `eos`, for laste decoder input token. 
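# decoder_out[i][j][w] holds the log-probability of token w at step j for
# hypothesis i (forward_attention_decoder is expected to return log_softmax
# output, which is why the explicit log_softmax above was dropped); the final
# step below scores `eos` as the terminating token.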
score += decoder_out[i][len(hyp[0])][self.eos] - if self.reverse_weight > 0: + + logger.debug( + f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}" + ) + + if reverse_weight > 0: r_score = 0.0 for j, w in enumerate(hyp[0]): r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] r_score += r_decoder_out[i][len(hyp[0])][self.eos] - score = score * (1 - self.reverse_weight - ) + r_score * self.reverse_weight + + logger.debug( + f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}" + ) + + score = score * (1 - reverse_weight) + r_score * reverse_weight # add ctc score (which in ln domain) score += hyp[1] * ctc_weight if score > best_score: best_score = score best_index = i + + logger.debug(f"result: {hyps[best_index]}") return hyps[best_index][0] - #@jit.to_static + @jit.to_static(property=True) def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - #@jit.to_static + @jit.to_static(property=True) def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - #@jit.to_static + @jit.to_static(property=True) def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - #@jit.to_static + @jit.to_static(property=True) def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ return self.eos - @jit.to_static + @jit.to_static(property=True) + def is_bidirectional_decoder(self) -> bool: + """ + Returns: + paddle.Tensor: decoder output + """ + if hasattr(self.decoder, 'right_decoder'): + return True + else: + return False + + # @jit.to_static def forward_encoder_chunk( self, xs: paddle.Tensor, @@ -681,28 +694,16 @@ def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: Args: xs (paddle.Tensor): encoder output, (B, T, D) Returns: - paddle.Tensor: activation before ctc + paddle.Tensor: activation before ctc. (B, Tmax, odim) """ return self.ctc.log_softmax(xs) # @jit.to_static - def is_bidirectional_decoder(self) -> bool: - """ - Returns: - paddle.Tensor: decoder output - """ - if hasattr(self.decoder, 'right_decoder'): - return True - else: - return False - - # @jit.to_static - def forward_attention_decoder( - self, - hyps: paddle.Tensor, - hyps_lens: paddle.Tensor, - encoder_out: paddle.Tensor, - reverse_weight: float=0.0, ) -> paddle.Tensor: + def forward_attention_decoder(self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, + reverse_weight: float=0.0) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: @@ -747,7 +748,8 @@ def decode(self, ctc_weight: float=0.0, decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, - simulate_streaming: bool=False): + simulate_streaming: bool=False, + reverse_weight: float=0.0): """u2 decoding. Args: @@ -766,6 +768,7 @@ def decode(self, num_decoding_left_chunks (int, optional): number of left chunks for decoding. Defaults to -1. simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. + reverse_weight (float, optional): reverse decoder weight, used by `attention_rescoring`. Raises: ValueError: when not support decoding_method. 
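With `reverse_weight` now threaded through `decode()` into `attention_rescoring`, the value comes from the decode config rather than from `model_conf`. A hedged usage sketch mirroring the call in `test_wav.py`/`quant.py` above (`model`, `xs`, `ilen` and `text_feature` are assumed to be prepared as in those scripts):

```python
result_transcripts = model.decode(
    xs, ilen,
    text_feature=text_feature,
    decoding_method='attention_rescoring',
    beam_size=10,
    ctc_weight=0.5,
    decoding_chunk_size=16,            # streaming-style decoding
    num_decoding_left_chunks=-1,
    simulate_streaming=True,
    reverse_weight=0.3)                # 0.0 for U2, e.g. 0.3 for U2++ (bitransformer)
hyp = result_transcripts[0][0]
```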
@@ -819,7 +822,8 @@ def decode(self, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks, ctc_weight=ctc_weight, - simulate_streaming=simulate_streaming) + simulate_streaming=simulate_streaming, + reverse_weight=reverse_weight) hyps = [hyp] else: raise ValueError(f"Not support decoding method: {decoding_method}") @@ -980,6 +984,49 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) + from paddlespeech.s2t.modules.fbank import KaldiFbank + import yaml + import json + import numpy as np + + input_dim = configs['input_dim'] + process = configs['preprocess_config'] + with open(process, encoding="utf-8") as f: + conf = yaml.safe_load(f) + assert isinstance(conf, dict), type(self.conf) + + for idx, process in enumerate(conf['process']): + assert isinstance(process, dict), type(process) + opts = dict(process) + process_type = opts.pop("type") + + if process_type == 'fbank_kaldi': + opts.update({'n_mels': input_dim}) + opts['dither'] = 0.0 + self.fbank = KaldiFbank(**opts) + logger.info(f"{self.__class__.__name__} export: {self.fbank}") + if process_type == 'cmvn_json': + # align with paddlespeech.audio.transform.cmvn:GlobalCMVN + std_floor = 1.0e-20 + + cmvn = opts['cmvn_path'] + if isinstance(cmvn, dict): + cmvn_stats = cmvn + else: + with open(cmvn) as f: + cmvn_stats = json.load(f) + count = cmvn_stats['frame_num'] + mean = np.array(cmvn_stats['mean_stat']) / count + square_sums = np.array(cmvn_stats['var_stat']) + var = square_sums / count - mean**2 + std = np.maximum(np.sqrt(var), std_floor) + istd = 1.0 / std + self.global_cmvn = GlobalCMVN( + paddle.to_tensor(mean, dtype=paddle.float), + paddle.to_tensor(istd, dtype=paddle.float)) + logger.info( + f"{self.__class__.__name__} export: {self.global_cmvn}") + def forward(self, feats, feats_lengths, @@ -995,9 +1042,25 @@ def forward(self, Returns: List[List[int]]: best path result """ - return self.ctc_greedy_search( - feats, - feats_lengths, - decoding_chunk_size=decoding_chunk_size, - num_decoding_left_chunks=num_decoding_left_chunks, - simulate_streaming=simulate_streaming) + # dummy code for dy2st + # return self.ctc_greedy_search( + # feats, + # feats_lengths, + # decoding_chunk_size=decoding_chunk_size, + # num_decoding_left_chunks=num_decoding_left_chunks, + # simulate_streaming=simulate_streaming) + return feats, feats_lengths + + def forward_feature(self, x): + """feature pipeline. + + Args: + x (paddle.Tensor): waveform (T,). + + Return: + feat (paddle.Tensor): feature (T, D) + """ + x = paddle.cast(x, paddle.float32) + feat = self.fbank(x) + feat = self.global_cmvn(feat) + return feat diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index e8b61bc0d47..31defbbaf1b 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -111,10 +111,7 @@ def forward( encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. 
ST-decoder branch start = time.time() diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 2d236743a6d..128f87c07e9 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -19,6 +19,7 @@ import paddle from paddle import nn +from paddle.nn import functional as F from paddle.nn import initializer as I from paddlespeech.s2t.modules.align import Linear @@ -45,6 +46,7 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float): """ super().__init__() assert n_feat % n_head == 0 + self.n_feat = n_feat # We assume d_v always equals d_k self.d_k = n_feat // n_head self.h = n_head @@ -54,6 +56,16 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float): self.linear_out = Linear(n_feat, n_feat) self.dropout = nn.Dropout(p=dropout_rate) + def _build_once(self, *args, **kwargs): + super()._build_once(*args, **kwargs) + # if self.self_att: + # self.linear_kv = Linear(self.n_feat, self.n_feat*2) + if not self.training: + self.weight = paddle.concat( + [self.linear_k.weight, self.linear_v.weight], axis=-1) + self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias]) + self._built = True + def forward_qkv(self, query: paddle.Tensor, key: paddle.Tensor, @@ -73,9 +85,16 @@ def forward_qkv(self, (#batch, n_head, time2, d_k). """ n_batch = query.shape[0] + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + if self.training: + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + else: + k, v = F.linear(key, self.weight, self.bias).view( + n_batch, -1, 2 * self.h, self.d_k).split( + 2, axis=2) + q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) @@ -108,10 +127,10 @@ def forward_attention( # When will `if mask.size(2) > 0` be False? # 1. onnx(16/-1, -1/-1, 16/0) # 2. jit (16/-1, -1/-1, 16/0, 16/4) - if paddle.shape(mask)[2] > 0: # time2 > 0 + if mask.shape[2] > 0: # time2 > 0 mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2) # for last chunk, time2 might be larger than scores.size(-1) - mask = mask[:, :, :, :paddle.shape(scores)[-1]] + mask = mask[:, :, :, :scores.shape[-1]] scores = scores.masked_fill(mask, -float('inf')) attn = paddle.softmax( scores, axis=-1).masked_fill(mask, @@ -179,7 +198,7 @@ def forward(self, # >>> torch.equal(b, c) # True # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True - if paddle.shape(cache)[0] > 0: + if cache.shape[0] > 0: # last dim `d_k * 2` for (key, val) key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) @@ -188,8 +207,9 @@ def forward(self, # non-trivial to calculate `next_cache_start` here. 
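# new_cache fuses keys and values on the last axis, shape (batch, head, time2, d_k * 2);
# the next chunk recovers them with paddle.split(cache, 2, axis=-1) as above.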
new_cache = paddle.concat((k, v), axis=-1) - scores = paddle.matmul(q, - k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) + # scores = paddle.matmul(q, + # k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) + scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask), new_cache @@ -270,7 +290,7 @@ def forward(self, and `head * d_k == size` """ q, k, v = self.forward_qkv(query, key, value) - q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) # when export onnx model, for 1st chunk, we feed # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) @@ -287,7 +307,7 @@ def forward(self, # >>> torch.equal(b, c) # True # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True - if paddle.shape(cache)[0] > 0: + if cache.shape[0] > 0: # last dim `d_k * 2` for (key, val) key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) @@ -301,19 +321,23 @@ def forward(self, p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + q_with_bias_u = q + self.pos_bias_u.unsqueeze(1) # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + q_with_bias_v = q + self.pos_bias_v.unsqueeze(1) # compute attention score # first compute matrix a and matrix c # as described in https://arxiv.org/abs/1901.02860 Section 3.3 # (batch, head, time1, time2) - matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True) # compute matrix b and matrix d # (batch, head, time1, time2) - matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) # Remove rel_shift since it is useless in speech recognition, # and it requires special attention for streaming. 
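# With q left in (batch, head, time1, d_k) layout, pos_bias_u.unsqueeze(1) has
# shape (head, 1, d_k) and broadcasts over batch and time1, and
# matmul(..., transpose_y=True) computes q @ k^T without materializing
# k.transpose([0, 1, 3, 2]), so the commented-out transposes are redundant.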
# matrix_bd = self.rel_shift(matrix_bd) diff --git a/paddlespeech/s2t/modules/cmvn.py b/paddlespeech/s2t/modules/cmvn.py index 67f71b6678e..6a8c1660cf6 100644 --- a/paddlespeech/s2t/modules/cmvn.py +++ b/paddlespeech/s2t/modules/cmvn.py @@ -40,6 +40,13 @@ def __init__(self, self.register_buffer("mean", mean) self.register_buffer("istd", istd) + def __repr__(self): + return ("{name}(mean={mean}, istd={istd}, norm_var={norm_var})".format( + name=self.__class__.__name__, + mean=self.mean, + istd=self.istd, + norm_var=self.norm_var)) + def forward(self, x: paddle.Tensor): """ Args: diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index be6056546c4..09d903eee34 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -127,11 +127,11 @@ def forward( x = x.transpose([0, 2, 1]) # [B, C, T] # mask batch padding - if paddle.shape(mask_pad)[2] > 0: # time > 0 + if mask_pad.shape[2] > 0: # time > 0 x = x.masked_fill(mask_pad, 0.0) if self.lorder > 0: - if paddle.shape(cache)[2] == 0: # cache_t == 0 + if cache.shape[2] == 0: # cache_t == 0 x = nn.functional.pad( x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') else: @@ -161,7 +161,7 @@ def forward( x = self.pointwise_conv2(x) # mask batch padding - if paddle.shape(mask_pad)[2] > 0: # time > 0 + if mask_pad.shape[2] > 0: # time > 0 x = x.masked_fill(mask_pad, 0.0) x = x.transpose([0, 2, 1]) # [B, T, C] diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 3b1a7f23d60..4ddf057b661 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -140,9 +140,7 @@ def forward( # m: (1, L, L) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) # tgt_mask: (B, L, L) - # TODO(Hui Zhang): not support & for tensor - # tgt_mask = tgt_mask & m - tgt_mask = tgt_mask.logical_and(m) + tgt_mask = tgt_mask & m x, _ = self.embed(tgt) for layer in self.decoders: @@ -153,9 +151,7 @@ def forward( if self.use_output_layer: x = self.output_layer(x) - # TODO(Hui Zhang): reduce_sum not support bool type - # olens = tgt_mask.sum(1) - olens = tgt_mask.astype(paddle.int).sum(1) + olens = tgt_mask.sum(1) return x, paddle.to_tensor(0.0), olens def forward_one_step( @@ -247,7 +243,7 @@ def batch_score(self, ] # batch decoding - ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0) # (B,L,L) + ys_mask = subsequent_mask(ys.shape[-1]).unsqueeze(0) # (B,L,L) xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) logp, states = self.forward_one_step( xs, xs_mask, ys, ys_mask, cache=batch_state) @@ -343,7 +339,7 @@ def forward( """ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens) - r_x = paddle.to_tensor(0.0) + r_x = paddle.zeros([1]) if reverse_weight > 0.0: r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, ys_in_lens) diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index 37b124e8499..cb7261107c2 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -114,10 +114,7 @@ def forward( ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}" tgt_q = tgt[:, -1:, :] residual = residual[:, -1:, :] - # TODO(Hui Zhang): slice not support bool type - # tgt_q_mask = tgt_mask[:, -1:, :] - tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast( - paddle.bool) + tgt_q_mask = tgt_mask[:, -1:, :] if self.concat_after: tgt_concat = 
paddle.cat( diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py index 3aeebd29b39..f41a7b5d4c3 100644 --- a/paddlespeech/s2t/modules/embedding.py +++ b/paddlespeech/s2t/modules/embedding.py @@ -89,7 +89,7 @@ def __init__(self, self.max_len = max_len self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.dropout = nn.Dropout(p=dropout_rate) - self.pe = paddle.zeros([self.max_len, self.d_model]) #[T,D] + self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D] position = paddle.arange( 0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1] @@ -97,9 +97,8 @@ def __init__(self, paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * -(math.log(10000.0) / self.d_model)) - self.pe[:, 0::2] = paddle.sin(position * div_term) - self.pe[:, 1::2] = paddle.cos(position * div_term) - self.pe = self.pe.unsqueeze(0) #[1, T, D] + self.pe[:, :, 0::2] = paddle.sin(position * div_term) + self.pe[:, :, 1::2] = paddle.cos(position * div_term) def forward(self, x: paddle.Tensor, offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -111,12 +110,10 @@ def forward(self, x: paddle.Tensor, paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) """ - T = x.shape[1] assert offset + x.shape[ 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( offset, x.shape[1], self.max_len) - #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor - pos_emb = self.pe[:, offset:offset + T] + pos_emb = self.pe[:, offset:offset + x.shape[1]] x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) @@ -165,6 +162,5 @@ def forward(self, x: paddle.Tensor, 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( offset, x.shape[1], self.max_len) x = x * self.xscale - #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]] return self.dropout(x), self.dropout(pos_emb) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 2f4ad1b2985..fd7bd7b9a1d 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -164,12 +164,8 @@ def forward( if self.global_cmvn is not None: xs = self.global_cmvn(xs) - #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor - xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) - #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor - masks = masks.astype(paddle.bool) - #TODO(Hui Zhang): mask_pad = ~masks - mask_pad = masks.logical_not() + xs, pos_emb, masks = self.embed(xs, masks, offset=0) + mask_pad = ~masks chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, @@ -215,11 +211,8 @@ def forward_chunk( same shape as the original cnn_cache """ assert xs.shape[0] == 1 # batch size must be one - # tmp_masks is just for interface compatibility - # TODO(Hui Zhang): stride_slice not support bool tensor - # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool) - tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32) - tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] + # tmp_masks is just for interface compatibility, [B=1, C=1, T] + tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool) if self.global_cmvn is not None: xs = 
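The positional-encoding change above allocates `pe` with the batch axis up front (`[1, max_len, d_model]`) and slices it directly with `x.shape[1]`, removing the later `unsqueeze(0)` and the Python-int workaround. A minimal sketch with toy sizes assumed for illustration:

```python
import math

import paddle

# Toy sizes, assumed for illustration only.
max_len, d_model = 8, 6
pe = paddle.zeros([1, max_len, d_model])                    # [B=1, T, D]
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
    paddle.arange(0, d_model, 2, dtype=paddle.float32) *
    -(math.log(10000.0) / d_model))
pe[:, :, 0::2] = paddle.sin(position * div_term)
pe[:, :, 1::2] = paddle.cos(position * div_term)

offset, T = 2, 3
pos_emb = pe[:, offset:offset + T]                          # [1, T, D]
print(pos_emb.shape)  # [1, 3, 6]
```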
self.global_cmvn(xs) @@ -228,9 +221,8 @@ def forward_chunk( xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) # after embed, xs=(B=1, chunk_size, hidden-dim) - elayers = paddle.shape(att_cache)[0] - cache_t1 = paddle.shape(att_cache)[2] - chunk_size = paddle.shape(xs)[1] + elayers, _, cache_t1, _ = att_cache.shape + chunk_size = xs.shape[1] attention_key_size = cache_t1 + chunk_size # only used when using `RelPositionMultiHeadedAttention` @@ -249,25 +241,30 @@ def forward_chunk( for i, layer in enumerate(self.encoders): # att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) + + # WARNING: eliminate if-else cond op in graph + # tensor zeros([0,0,0,0]) support [i:i+1] slice, will return zeros([0,0,0,0]) tensor + # raw code as below: + # att_cache=att_cache[i:i+1] if elayers > 0 else att_cache, + # cnn_cache=cnn_cache[i:i+1] if cnn_cache.shape[0] > 0 else cnn_cache, xs, _, new_att_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i:i + 1] - if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) + att_cache=att_cache[i:i + 1], + cnn_cache=cnn_cache[i:i + 1], ) # new_att_cache = (1, head, attention_key_size, d_k*2) # new_cnn_cache = (B=1, hidden-dim, cache_t2) r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim + r_cnn_cache.append(new_cnn_cache) # add elayer dim if self.normalize_before: xs = self.after_norm(xs) # r_att_cache (elayers, head, T, d_k*2) - # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2) + # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2) r_att_cache = paddle.concat(r_att_cache, axis=0) - r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) + r_cnn_cache = paddle.stack(r_cnn_cache, axis=0) return xs, r_att_cache, r_cnn_cache def forward_chunk_by_chunk( @@ -397,11 +394,7 @@ def forward_one_step( if self.global_cmvn is not None: xs = self.global_cmvn(xs) - #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor - xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) - #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor - masks = masks.astype(paddle.bool) - + xs, pos_emb, masks = self.embed(xs, masks, offset=0) if cache is None: cache = [None for _ in range(len(self.encoders))] new_cache = [] diff --git a/paddlespeech/s2t/modules/fbank.py b/paddlespeech/s2t/modules/fbank.py new file mode 100644 index 00000000000..8d76a472775 --- /dev/null +++ b/paddlespeech/s2t/modules/fbank.py @@ -0,0 +1,72 @@ +import paddle +from paddle import nn + +from paddlespeech.audio.compliance import kaldi +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['KaldiFbank'] + + +class KaldiFbank(nn.Layer): + def __init__( + self, + fs=16000, + n_mels=80, + n_shift=160, # unit:sample, 10ms + win_length=400, # unit:sample, 25ms + energy_floor=0.0, + dither=0.0): + """ + Args: + fs (int): sample rate of the audio + n_mels (int): number of mel filter banks + n_shift (int): number of points in a frame shift + win_length (int): number of points in a frame windows + energy_floor (float): Floor on energy in Spectrogram computation (absolute) + dither (float): Dithering constant. 
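As the comment in `forward_chunk` above notes, the per-layer `if elayers > 0` / `if cnn_cache.shape[0] > 0` branches can be dropped for export because slicing an empty cache tensor just returns another empty tensor, so no conditional op lands in the static graph. A small sketch of that behaviour (layer, head, and cache sizes are assumptions):

```python
import paddle

# First chunk: no cache yet, modelled as an empty 4-D tensor.
att_cache = paddle.zeros([0, 0, 0, 0])
print(att_cache[0:1].shape)  # [0, 0, 0, 0] -- slice stays empty, no branch needed

# Later chunks: (elayers, head, cache_t1, d_k * 2), e.g. 12 layers.
att_cache = paddle.zeros([12, 8, 16, 128])
elayers, _, cache_t1, _ = att_cache.shape
print(elayers, cache_t1, att_cache[3:4].shape)  # 12 16 [1, 8, 16, 128]
```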
Default 0.0 + """ + super().__init__() + self.fs = fs + self.n_mels = n_mels + num_point_ms = fs / 1000 + self.n_frame_length = win_length / num_point_ms + self.n_frame_shift = n_shift / num_point_ms + self.energy_floor = energy_floor + self.dither = dither + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, " + "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, " + "dither={dither}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_frame_shift=self.n_frame_shift, + n_frame_length=self.n_frame_length, + dither=self.dither, )) + + def forward(self, x: paddle.Tensor): + """ + Args: + x (paddle.Tensor): shape (Ti). + Not support: [Time, Channel] and Batch mode. + + Returns: + paddle.Tensor: (T, D) + """ + assert x.ndim == 1 + + feat = kaldi.fbank( + x.unsqueeze(0), # append channel dim, (C, Ti) + n_mels=self.n_mels, + frame_length=self.n_frame_length, + frame_shift=self.n_frame_shift, + dither=self.dither, + energy_floor=self.energy_floor, + sr=self.fs) + + assert feat.ndim == 2 # (T,D) + return feat diff --git a/paddlespeech/s2t/modules/loss.py b/paddlespeech/s2t/modules/loss.py index 884fb70c10f..afd5201aa8e 100644 --- a/paddlespeech/s2t/modules/loss.py +++ b/paddlespeech/s2t/modules/loss.py @@ -85,7 +85,7 @@ def forward(self, logits, ys_pad, hlens, ys_lens): Returns: [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}. """ - B = paddle.shape(logits)[0] + B = logits.shape[0] # warp-ctc need logits, and do softmax on logits by itself # warp-ctc need activation with shape [T, B, V + 1] # logits: (B, L, D) -> (L, B, D) @@ -158,7 +158,7 @@ def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor: Returns: loss (paddle.Tensor) : The KL loss, scalar float value """ - B, T, D = paddle.shape(x) + B, T, D = x.shape assert D == self.size x = x.reshape((-1, self.size)) target = target.reshape([-1]) diff --git a/paddlespeech/s2t/modules/mask.py b/paddlespeech/s2t/modules/mask.py index 1f66c015acb..65619eb9076 100644 --- a/paddlespeech/s2t/modules/mask.py +++ b/paddlespeech/s2t/modules/mask.py @@ -109,12 +109,7 @@ def subsequent_mask(size: int) -> paddle.Tensor: [1, 1, 1]] """ ret = paddle.ones([size, size], dtype=paddle.bool) - #TODO(Hui Zhang): tril not support bool - #return paddle.tril(ret) - ret = ret.astype(paddle.float) - ret = paddle.tril(ret) - ret = ret.astype(paddle.bool) - return ret + return paddle.tril(ret) def subsequent_chunk_mask( diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 88451ddd77f..782a437ee85 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -139,8 +139,8 @@ def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 """ x = x.unsqueeze(1) # (b, c=1, t, f) x = self.conv(x) - b, c, t, f = paddle.shape(x) - x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) + b, c, t, f = x.shape + x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] @@ -192,8 +192,8 @@ def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) - b, c, t, f = paddle.shape(x) - x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) + b, c, t, f = x.shape + x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, 
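A usage sketch for the new `KaldiFbank` front-end defined above. The one-second random waveform is an assumption for illustration; the layer expects a mono waveform tensor of shape `(Ti,)` and returns `(T, n_mels)` features:

```python
import paddle

from paddlespeech.s2t.modules.fbank import KaldiFbank

# Defaults mirror the constructor above: 16 kHz audio, 80 mel bins,
# 10 ms shift (160 samples) and 25 ms window (400 samples).
fbank = KaldiFbank(fs=16000, n_mels=80, n_shift=160, win_length=400)

wav = paddle.randn([16000])   # 1 s of fake mono audio, shape (Ti,)
feat = fbank(wav)             # log-mel features, shape (T, 80)
print(feat.shape)
```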
:-2:2][:, :, :-4:3] @@ -245,6 +245,7 @@ def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) - x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) + b, c, t, f = x.shape + x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py index 422d4f82a9a..3ac102f3c76 100644 --- a/paddlespeech/s2t/utils/tensor_utils.py +++ b/paddlespeech/s2t/utils/tensor_utils.py @@ -184,13 +184,8 @@ def th_accuracy(pad_outputs: paddle.Tensor, pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], pad_outputs.shape[1]).argmax(2) mask = pad_targets != ignore_label - #TODO(Hui Zhang): sum not support bool type - # numerator = paddle.sum( - # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = ( + + numerator = paddle.sum( pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = paddle.sum(numerator.type_as(pad_targets)) - #TODO(Hui Zhang): sum not support bool type - # denominator = paddle.sum(mask) - denominator = paddle.sum(mask.type_as(pad_targets)) + denominator = paddle.sum(mask) return float(numerator) / float(denominator) diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index adcd9bc14e6..67bbb4d48b4 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -22,7 +22,6 @@ from yacs.config import CfgNode from paddlespeech.audio.transform.transformation import Transformation -from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource @@ -610,22 +609,15 @@ def rescoring(self): hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, self.model.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining - encoder_out = self.encoder_out.repeat(beam_size, 1, 1) - encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - - r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, - self.model.sos, self.model.eos) - decoder_out, r_decoder_out, _ = self.model.decoder( - encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, - self.model.reverse_weight) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain - decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - decoder_out = decoder_out.numpy() + # (beam_size, max_hyps_len, vocab_size) + decoder_out, r_decoder_out = self.model.forward_attention_decoder( + hyps_pad, hyps_lens, self.encoder_out, self.model.reverse_weight) + decoder_out = decoder_out.numpy() # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # conventional transformer decoder. - r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) r_decoder_out = r_decoder_out.numpy() # Only use decoder score for rescoring diff --git a/tests/unit/asr/reverse_pad_list.py b/tests/unit/asr/reverse_pad_list.py new file mode 100644 index 00000000000..215ed5ceb65 --- /dev/null +++ b/tests/unit/asr/reverse_pad_list.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
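The `th_accuracy` cleanup above relies on the same bool-tensor support: `paddle.sum` now takes the bool comparison result and the bool mask directly, so the `type_as` casts go away. A minimal sketch with toy predictions and targets (values are assumptions):

```python
import paddle

ignore_label = -1
pad_pred = paddle.to_tensor([[1, 2, 3], [4, 5, 0]])
pad_targets = paddle.to_tensor([[1, 2, 9], [4, 5, ignore_label]])

mask = pad_targets != ignore_label          # bool, (B, L)
numerator = paddle.sum(
    pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
print(float(numerator) / float(denominator))  # 0.8
```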
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import paddle + +import paddlespeech.s2t # noqa: F401 +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import pad_sequence + +# from paddlespeech.audio.utils.tensor_utils import reverse_pad_list + + +def reverse_pad_list(ys_pad: paddle.Tensor, + ys_lens: paddle.Tensor, + pad_value: float=-1.0) -> paddle.Tensor: + """Reverse padding for the list of tensors. + Args: + ys_pad (tensor): The padded tensor (B, Tokenmax). + ys_lens (tensor): The lens of token seqs (B) + pad_value (int): Value for padding. + Returns: + Tensor: Padded tensor (B, Tokenmax). + Examples: + >>> x + tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) + >>> pad_list(x, 0) + tensor([[4, 3, 2, 1], + [7, 6, 5, 0], + [9, 8, 0, 0]]) + """ + r_ys_pad = pad_sequence([(paddle.flip(y[:i], [0])) + for y, i in zip(ys_pad, ys_lens)], True, pad_value) + return r_ys_pad + + +def naive_reverse_pad_list_with_sos_eos(r_hyps, + r_hyps_lens, + sos=5000, + eos=5000, + ignore_id=-1): + r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(ignore_id)) + r_hyps, _ = add_sos_eos(r_hyps, sos, eos, ignore_id) + return r_hyps + + +def reverse_pad_list_with_sos_eos(r_hyps, + r_hyps_lens, + sos=5000, + eos=5000, + ignore_id=-1): + # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) + # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + max_len = paddle.max(r_hyps_lens) + index_range = paddle.arange(0, max_len, 1) + seq_len_expand = r_hyps_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + + index = (seq_len_expand - 1) - index_range # (beam, max_len) + # >>> index + # >>> tensor([[ 2, 1, 0], + # >>> [ 2, 1, 0], + # >>> [ 0, -1, -2]]) + index = index * seq_mask + + # >>> index + # >>> tensor([[2, 1, 0], + # >>> [2, 1, 0], + # >>> [0, 0, 0]]) + def paddle_gather(x, dim, index): + index_shape = index.shape + index_flatten = index.flatten() + if dim < 0: + dim = len(x.shape) + dim + nd_index = [] + for k in range(len(x.shape)): + if k == dim: + nd_index.append(index_flatten) + else: + reshape_shape = [1] * len(x.shape) + reshape_shape[k] = x.shape[k] + x_arange = paddle.arange(x.shape[k], dtype=index.dtype) + x_arange = x_arange.reshape(reshape_shape) + dim_index = paddle.expand(x_arange, index_shape).flatten() + nd_index.append(dim_index) + ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64") + paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape) + return paddle_out + + r_hyps = paddle_gather(r_hyps, 1, index) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, 2, 2]]) + r_hyps = paddle.where(seq_mask, r_hyps, eos) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, eos, eos]]) + B = r_hyps.shape[0] + _sos = paddle.ones([B, 1], dtype=r_hyps.dtype) * sos + # r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1) + r_hyps = paddle.concat([_sos, r_hyps], axis=1) + # >>> r_hyps + # >>> 
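A tiny worked example (values are assumptions) of what the reversal helpers above compute: each hypothesis is flipped up to its true length, the padding is restored by `pad_sequence`, and `add_sos_eos` wraps the result for the right-to-left decoder pass; the gather-based `reverse_pad_list_with_sos_eos` reproduces this without Python-side loops so it can run in an exported static graph.

```python
import paddle

ys_pad = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 0]])
ys_lens = [4, 3]   # true lengths; plain ints keep the sketch simple

# Flip only the valid prefix of each row, as reverse_pad_list does.
rows = [paddle.flip(y[:n], [0]) for y, n in zip(ys_pad, ys_lens)]
print([r.tolist() for r in rows])
# [[4, 3, 2, 1], [7, 6, 5]] -- pad_sequence then pads the short row back
# to length 4, and add_sos_eos prepends sos / appends eos, as in the test below.
```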
tensor([[sos, 3, 2, 1], + # >>> [sos, 4, 8, 9], + # >>> [sos, 2, eos, eos]]) + return r_hyps + + +class TestU2Model(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + + self.sos = 5000 + self.eos = 5000 + self.ignore_id = -1 + self.reverse_hyps = paddle.to_tensor([[4, 3, 2, 1, -1], + [5, 4, 3, 2, 1]]) + self.reverse_hyps_sos_eos = paddle.to_tensor( + [[self.sos, 4, 3, 2, 1, self.eos], [self.sos, 5, 4, 3, 2, 1]]) + + self.hyps = paddle.to_tensor([[1, 2, 3, 4, -1], [1, 2, 3, 4, 5]]) + + self.hyps_lens = paddle.to_tensor([4, 5], paddle.int32) + + def test_reverse_pad_list(self): + r_hyps = reverse_pad_list(self.hyps, self.hyps_lens) + self.assertSequenceEqual(r_hyps.tolist(), self.reverse_hyps.tolist()) + + def test_naive_reverse_pad_list_with_sos_eos(self): + r_hyps_sos_eos = naive_reverse_pad_list_with_sos_eos(self.hyps, + self.hyps_lens) + self.assertSequenceEqual(r_hyps_sos_eos.tolist(), + self.reverse_hyps_sos_eos.tolist()) + + def test_static_reverse_pad_list_with_sos_eos(self): + r_hyps_sos_eos_static = reverse_pad_list_with_sos_eos(self.hyps, + self.hyps_lens) + self.assertSequenceEqual(r_hyps_sos_eos_static.tolist(), + self.reverse_hyps_sos_eos.tolist()) + + +if __name__ == '__main__': + unittest.main()
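The unit test above checks that the gather-based `reverse_pad_list_with_sos_eos` produces exactly the same sos/eos-wrapped reversed hypotheses as the naive flip-and-pad implementation; it runs standalone, e.g. with `python tests/unit/asr/reverse_pad_list.py`.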