[ASR] support wav2vec2-zh cli, test=asr (#2697)
* support wav2vec2-zh cli, test=asr
Zth9730 authored Nov 29, 2022
1 parent a01c163 commit c67bf7b
Showing 4 changed files with 106 additions and 22 deletions.
32 changes: 25 additions & 7 deletions paddlespeech/cli/ssl/infer.py
@@ -25,6 +25,7 @@
import numpy as np
import paddle
import soundfile
+from paddlenlp.transformers import AutoTokenizer
from yacs.config import CfgNode

from ..executor import BaseExecutor
@@ -50,7 +51,7 @@ def __init__(self):
self.parser.add_argument(
'--model',
type=str,
-            default='wav2vec2ASR_librispeech',
+            default=None,
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
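
For reference, the `choices` expression above derives model names by truncating each resource tag at its first '-'. A minimal sketch of that computation, applied to the two resource tags this commit registers in pretrained_models.py (tag format: {model}-{lang}-{sample_rate}):

# Minimal sketch of the `choices` computation above, using the two
# resource tags added by this commit.
tags = ["wav2vec2-zh-16k", "wav2vec2ASR_aishell1-zh-16k"]
print([tag[:tag.index('-')] for tag in tags])
# ['wav2vec2', 'wav2vec2ASR_aishell1']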
@@ -123,7 +124,7 @@ def __init__(self):
help='Increase logger verbosity of current task.')

def _init_from_path(self,
-                        model_type: str='wav2vec2ASR_librispeech',
+                        model_type: str=None,
task: str='asr',
lang: str='en',
sample_rate: int=16000,
@@ -134,6 +135,18 @@ def _init_from_path(self,
Init model and other resources from a specific path.
"""
logger.debug("start to init the model")

+        if model_type is None:
+            if lang == 'en':
+                model_type = 'wav2vec2ASR_librispeech'
+            elif lang == 'zh':
+                model_type = 'wav2vec2ASR_aishell1'
+            else:
+                logger.error(
+                    "invalid lang, please input --lang en or --lang zh")
+            logger.debug(
+                "Model type had not been specified, default {} was used.".
+                format(model_type))
# default max_len: unit:second
self.max_len = 50
if hasattr(self, 'model'):
@@ -167,9 +180,13 @@ def _init_from_path(self,
self.config.merge_from_file(self.cfg_path)
if task == 'asr':
with UpdateConfig(self.config):
-                self.text_feature = TextFeaturizer(
-                    unit_type=self.config.unit_type,
-                    vocab=self.config.vocab_filepath)
+                if lang == 'en':
+                    self.text_feature = TextFeaturizer(
+                        unit_type=self.config.unit_type,
+                        vocab=self.config.vocab_filepath)
+                elif lang == 'zh':
+                    self.text_feature = AutoTokenizer.from_pretrained(
+                        self.config.tokenizer)
self.config.decode.decoding_method = decode_method
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
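
In the zh branch above, `text_feature` becomes a paddlenlp tokenizer rather than a `TextFeaturizer`, so hypotheses are later decoded into BERT-style tokens. An illustrative sketch, where 'bert-base-chinese' is only a stand-in for whatever name `self.config.tokenizer` carries:

# Illustration of the zh branch above; 'bert-base-chinese' is a placeholder
# for the tokenizer name stored in self.config.tokenizer.
from paddlenlp.transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('bert-base-chinese')
ids = tok("今天天气很好")["input_ids"]
print(tok.convert_ids_to_tokens(ids))
# e.g. ['[CLS]', '今', '天', '天', '气', '很', '好', '[SEP]']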
@@ -253,7 +270,8 @@ def infer(self, model_type: str, task: str):
audio,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
-                beam_size=cfg.beam_size)
+                beam_size=cfg.beam_size,
+                tokenizer=getattr(self.config, 'tokenizer', None))
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
@@ -413,7 +431,7 @@ def execute(self, argv: List[str]) -> bool:
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
-                 model: str='wav2vec2ASR_librispeech',
+                 model: str=None,
task: str='asr',
lang: str='en',
sample_rate: int=16000,
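
With `model` now defaulting to None and resolved from `lang`, Chinese recognition needs only `lang='zh'`. A usage sketch based on the `__call__` signature above, assuming the executor class is exported as `SSLExecutor` and that `./zh.wav` is a local 16 kHz Mandarin recording:

# Usage sketch; the SSLExecutor import path and ./zh.wav are assumptions,
# while the keyword arguments mirror the __call__ signature in this diff.
from paddlespeech.cli.ssl import SSLExecutor

ssl_executor = SSLExecutor()
text = ssl_executor(
    audio_file='./zh.wav',
    task='asr',
    lang='zh',  # model stays None, so wav2vec2ASR_aishell1 is selected
    sample_rate=16000)
print(text)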
32 changes: 32 additions & 0 deletions paddlespeech/resource/pretrained_models.py
@@ -70,6 +70,38 @@
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"wav2vec2-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
'md5':
'00ea4975c05d1bb58181205674052fe1',
'cfg_path':
'model.yaml',
'ckpt_path':
'chinese-wav2vec2-large',
'model':
'chinese-wav2vec2-large.pdparams',
'params':
'chinese-wav2vec2-large.pdparams',
},
},
"wav2vec2ASR_aishell1-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
'md5':
'ac8fa0a6345e6a7535f6fabb5e59e218',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
}

# ---------------------------------
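
Each new entry pairs a download `url` with an `md5` checksum. A minimal sketch of how such an entry could be fetched and verified; this is an illustration, not PaddleSpeech's internal downloader:

# Minimal sketch, independent of PaddleSpeech's own download utilities:
# fetch a checkpoint archive and check it against the 'md5' field above.
import hashlib
import urllib.request

def fetch_and_verify(url: str, expected_md5: str, dest: str) -> str:
    urllib.request.urlretrieve(url, dest)
    md5 = hashlib.md5()
    with open(dest, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    if md5.hexdigest() != expected_md5:
        raise ValueError(f"checksum mismatch for {dest}")
    return dest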
4 changes: 0 additions & 4 deletions paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
@@ -1173,10 +1173,6 @@ def __init__(self, config):
self.proj_codevector_dim = config.proj_codevector_dim
self.diversity_loss_weight = config.diversity_loss_weight

-        # ctc loss
-        self.ctc_loss_reduction = config.ctc_loss_reduction
-        self.ctc_zero_infinity = config.ctc_zero_infinity
-
# adapter
self.add_adapter = config.add_adapter
self.adapter_kernel_size = config.adapter_kernel_size
60 changes: 49 additions & 11 deletions paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -76,28 +76,66 @@ def decode(self,
feats: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
-               beam_size: int):
+               beam_size: int,
+               tokenizer: str=None):
batch_size = feats.shape[0]

if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
-            logger.error(
-                f'decoding mode {decoding_method} must be running with batch_size == 1'
+            raise ValueError(
+                f"decoding mode {decoding_method} must be running with batch_size == 1"
            )
-            logger.error(f"current batch_size is {batch_size}")
-            sys.exit(1)

if decoding_method == 'ctc_greedy_search':
-            hyps = self.ctc_greedy_search(feats)
-            res = [text_feature.defeaturize(hyp) for hyp in hyps]
-            res_tokenids = [hyp for hyp in hyps]
+            if tokenizer is None:
+                hyps = self.ctc_greedy_search(feats)
+                res = [text_feature.defeaturize(hyp) for hyp in hyps]
+                res_tokenids = [hyp for hyp in hyps]
+            else:
+                hyps = self.ctc_greedy_search(feats)
+                res = []
+                res_tokenids = []
+                for sequence in hyps:
+                    # Decode token terms to words
+                    predicted_tokens = text_feature.convert_ids_to_tokens(
+                        sequence)
+                    tmp_res = []
+                    tmp_res_tokenids = []
+                    for c in predicted_tokens:
+                        if c == "[CLS]":
+                            continue
+                        elif c == "[SEP]" or c == "[PAD]":
+                            break
+                        else:
+                            tmp_res.append(c)
+                            tmp_res_tokenids.append(text_feature.vocab[c])
+                    res.append(''.join(tmp_res))
+                    res_tokenids.append(tmp_res_tokenids)
# ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
-            hyp = self.ctc_prefix_beam_search(feats, beam_size)
-            res = [text_feature.defeaturize(hyp)]
-            res_tokenids = [hyp]
+            if tokenizer is None:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = [text_feature.defeaturize(hyp)]
+                res_tokenids = [hyp]
+            else:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = []
+                res_tokenids = []
+                predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
+                tmp_res = []
+                tmp_res_tokenids = []
+                for c in predicted_tokens:
+                    if c == "[CLS]":
+                        continue
+                    elif c == "[SEP]" or c == "[PAD]":
+                        break
+                    else:
+                        tmp_res.append(c)
+                        tmp_res_tokenids.append(text_feature.vocab[c])
+                res.append(''.join(tmp_res))
+                res_tokenids.append(tmp_res_tokenids)
else:
raise ValueError(
f"wav2vec2 not support decoding method: {decoding_method}")
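
One note on the diff above: the `[CLS]`/`[SEP]`/`[PAD]` stripping loop appears verbatim in both decoding branches. A possible follow-up refactor (a sketch, not part of this commit) would factor it into a shared helper:

# Sketch of a helper both tokenizer branches above could share; assumes
# text_feature is the paddlenlp tokenizer used when tokenizer is not None.
def tokens_to_text(text_feature, hyp):
    """Strip BERT-style special tokens from one hypothesis; return (text, ids)."""
    res, res_tokenids = [], []
    for c in text_feature.convert_ids_to_tokens(hyp):
        if c == "[CLS]":
            continue  # leading classifier token carries no transcript text
        if c in ("[SEP]", "[PAD]"):
            break  # end of sequence / padding: nothing useful follows
        res.append(c)
        res_tokenids.append(text_feature.vocab[c])
    return ''.join(res), res_tokenids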
