Commit c9b0c96

Merge pull request #2502 from zh794390558/u2pp_export

[s2t] streaming conformer u2 and u2pp jit export

Zth9730 authored Oct 9, 2022
2 parents 0359c3f + c98b5dd commit c9b0c96
Showing 32 changed files with 1,142 additions and 217 deletions.
33 changes: 33 additions & 0 deletions examples/wenetspeech/asr1/README.md
@@ -12,3 +12,36 @@ show model.tar.gz
```
tar tf model.tar.gz
```

Another way to pack the model is:

```bash
tar cvzf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz model.yaml conf/tuning/ conf/chunk_conformer.yaml conf/preprocess.yaml data/mean_std.json exp/chunk_conformer/checkpoints/
```
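
To double-check the archive layout afterwards, list its contents the same way as above:

```bash
tar tf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz
```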

## Export Static Model

>> Requires Paddle >= 2.4
>> A sample manifest line from `data/test_meeting/data.list`:
>> {"input": [{"name": "input1", "shape": [3.2230625, 80], "feat": "/home/PaddleSpeech/dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0163.wav", "filetype": "sound"}], "output": [{"name": "target1", "shape": [9, 5538], "text": "\u697c\u5e02\u8c03\u63a7\u5c06\u53bb\u5411\u4f55\u65b9", "token": "\u697c \u5e02 \u8c03 \u63a7 \u5c06 \u53bb \u5411 \u4f55 \u65b9", "tokenid": "1891 1121 3502 1543 1018 477 528 163 1657"}], "utt": "BAC009S0764W0163", "utt2spk": "S0764"}
>> Test wav:
>> wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
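
Each manifest entry is a single JSON object per line, as in the sample above. A quick way to pretty-print the first entry for inspection (a minimal sketch, assuming the manifest already exists locally):

```bash
# Pretty-print the first manifest line to verify the input/output fields.
head -n 1 data/test_meeting/data.list | python3 -m json.tool
```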
### U2 chunk conformer
>> UniDecoder (unidirectional decoder)
>> Make sure `reverse_weight` in config is `0.0`
>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz
```
tar zxvf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz
./local/export.sh conf/chunk_conformer.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.jit
```
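
If the export succeeds, the static-graph files should appear under the given prefix; a quick sanity check (a sketch, assuming Paddle's usual `.pdmodel`/`.pdiparams` suffixes from `paddle.jit.save`):

```bash
# Expect export.jit.pdmodel (graph) plus export.jit.pdiparams (weights).
ls -lh ./export.jit*
```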

### U2++ chunk conformer
>> BiDecoder (bidirectional decoder)
>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.0.model.tar.gz
>> Make sure `reverse_weight` in config is not `0.0`
```
tar zxvf asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.0.model.tar.gz
./local/export.sh conf/chunk_conformer_u2pp.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.jit
```
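
Because U2 export requires `reverse_weight: 0.0` while U2++ requires a non-zero value, it is worth confirming both configs before exporting (a minimal sketch):

```bash
# The U2 config should report 0.0; the U2++ config should report 0.3.
grep -n "reverse_weight" conf/chunk_conformer.yaml conf/chunk_conformer_u2pp.yaml
```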
101 changes: 101 additions & 0 deletions examples/wenetspeech/asr1/conf/chunk_conformer.yaml
@@ -0,0 +1,101 @@
############################################
#          Network Architecture            #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 512    # dimension of attention
    attention_heads: 8
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    normalize_before: True
    use_cnn_module: True
    cnn_module_kernel: 15
    activation_type: swish
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 8
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.0 # unidirectional decoder
    length_normalized_loss: false
    init_type: 'kaiming_uniform'

# https://yaml.org/type/float.html
###########################################
#                 Data                    #
###########################################
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list

###########################################
#              Dataloader                 #
###########################################
use_streaming_data: True
unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
preprocess_config: conf/preprocess.yaml
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
do_filter: True
maxlen_in: 1200  # if do_filter == False && input length > maxlen_in, batchsize is automatically reduced
maxlen_out: 100  # if do_filter == False && output length > maxlen_out, batchsize is automatically reduced
minlen_in: 10
minlen_out: 0
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1


###########################################
#               Training                  #
###########################################
n_epoch: 26
accum_grad: 32
global_grad_clip: 5.0
dist_sampler: True
log_interval: 1
checkpoint:
    kbest_n: 50
    latest_n: 5
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 5000
    lr_decay: 1.0
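
For reference, `warmuplr` above is the Noam-style warmup scheduler; a sketch of the usual ESPnet/WeNet `WarmupLR` definition (assuming PaddleSpeech's implementation matches it):

$$
\mathrm{lr}(t) = \mathrm{lr}_{\mathrm{base}} \cdot \mathrm{warmup\_steps}^{0.5} \cdot \min\!\left(t^{-0.5},\ t \cdot \mathrm{warmup\_steps}^{-1.5}\right)
$$

With `lr: 0.001` and `warmup_steps: 5000`, the rate climbs linearly to its peak of 0.001 at step 5000, then decays as $t^{-0.5}$.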
100 changes: 100 additions & 0 deletions examples/wenetspeech/asr1/conf/chunk_conformer_u2pp.yaml
@@ -0,0 +1,100 @@
############################################
#          Network Architecture            #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 512    # dimension of attention
    attention_heads: 8
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    normalize_before: True
    use_cnn_module: True
    cnn_module_kernel: 15
    activation_type: swish
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 8
    linear_units: 2048
    num_blocks: 3      # the number of decoder blocks
    r_num_blocks: 3    # only for bitransformer
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3 # only for bitransformer decoder
    init_type: 'kaiming_uniform' # !Warning: needed for convergence

###########################################
#                 Data                    #
###########################################
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list

###########################################
#              Dataloader                 #
###########################################
use_stream_data: True
vocab_filepath: data/lang_char/vocab.txt
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
do_filter: True
maxlen_in: 1200  # if do_filter == False && input length > maxlen_in, batchsize is automatically reduced
maxlen_out: 100  # if do_filter == False && output length > maxlen_out, batchsize is automatically reduced
minlen_in: 10
minlen_out: 0
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1

###########################################
#               Training                  #
###########################################
n_epoch: 150
accum_grad: 8
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
    lr: 0.002
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
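
As a quick sanity check on the two recipes, the effective number of utterances per optimizer step per data-parallel worker is `batch_size` × `accum_grad` (multiply further by the number of GPUs for the global figure):

$$
\underbrace{32}_{\texttt{batch\_size}} \times \underbrace{32}_{\texttt{accum\_grad}} = 1024 \ \text{(U2)}, \qquad 32 \times 8 = 256 \ \text{(U2++)}
$$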
2 changes: 1 addition & 1 deletion examples/wenetspeech/asr1/conf/preprocess.yaml
@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: 0.1
+    dither: 1.0
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugment
12 changes: 12 additions & 0 deletions examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml
@@ -0,0 +1,12 @@
beam_size: 10
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
                        # <0: for decoding, use full chunk.
                        # >0: for decoding, use fixed chunk size as set.
                        # 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: True # simulate streaming inference. Defaults to False.
decode_batch_size: 128
error_rate_type: cer
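
To exercise this chunked, simulated-streaming decode configuration, point the recipe's test script at it (a hypothetical invocation sketch; the exact argument order of `local/test.sh` in this recipe may differ):

```bash
# Decode the test set with the chunked attention-rescoring settings above.
./local/test.sh conf/chunk_conformer.yaml conf/tuning/chunk_decode.yaml exp/chunk_conformer/checkpoints/avg_10
```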
9 changes: 5 additions & 4 deletions examples/wenetspeech/asr1/conf/tuning/decode.yaml
@@ -1,11 +1,12 @@
-decode_batch_size: 128
-error_rate_type: cer
-decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 beam_size: 10
+decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
 decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
                         # <0: for decoding, use full chunk.
                         # >0: for decoding, use fixed chunk size as set.
                         # 0: used for training, it's prohibited here.
 num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
-simulate_streaming: False # simulate streaming inference. Defaults to False.
+simulate_streaming: False # simulate streaming inference. Defaults to False.
+decode_batch_size: 128
+error_rate_type: cer
5 changes: 5 additions & 0 deletions examples/wenetspeech/asr1/local/export.sh
@@ -12,9 +12,14 @@ config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3


# Export cannot use the streaming dataloader, so set use_stream_data to False.
# u2:   reverse_weight should be 0.0
# u2pp: reverse_weight should match the config file, e.g. 0.3
python3 -u ${BIN_DIR}/export.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
    --opts use_stream_data False \
    --checkpoint_path ${ckpt_path_prefix} \
    --export_path ${jit_model_export_path}
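
Combined with the README section above, a typical call passes the training config, the averaged checkpoint prefix, and the jit output prefix:

```bash
# Export the U2 streaming conformer to a static jit graph.
./local/export.sh conf/chunk_conformer.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.jit
```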

59 changes: 59 additions & 0 deletions examples/wenetspeech/asr1/local/quant.sh
@@ -0,0 +1,59 @@
#!/bin/bash

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4

mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
if [ $? -ne 0 ]; then
    exit 1
fi

if [ ! -f ${audio_file} ]; then
    echo "Please input the right audio_file path"
    exit 1
fi


chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

for type in attention_rescoring; do
    echo "decoding ${type}"
    batch_size=1
    output_dir=${ckpt_prefix}
    mkdir -p ${output_dir}
    python3 -u ${BIN_DIR}/quant.py \
        --ngpu ${ngpu} \
        --config ${config_path} \
        --decode_cfg ${decode_config_path} \
        --result_file ${output_dir}/${type}.rsl \
        --checkpoint_path ${ckpt_prefix} \
        --opts decode.decoding_method ${type} \
        --opts decode.decode_batch_size ${batch_size} \
        --audio_file ${audio_file}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done
exit 0
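
A matching invocation, following the script's usage string and the demo wav it downloads (a sketch; paths are illustrative):

```bash
# Quantize and evaluate the U2++ checkpoint on a single utterance.
./local/quant.sh conf/chunk_conformer_u2pp.yaml conf/tuning/chunk_decode.yaml \
    exp/chunk_conformer/checkpoints/avg_10 data/demo_01_03.wav
```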