Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[NeuralChat] Fix tts crash with messy retrieval input and enhance normalizer #1088

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from datasets import load_dataset, Audio, Dataset, Features, ClassLabel
import os
import torch
from speechbrain.pretrained import EncoderClassifier
from typing import Any, Dict, List, Union
from transformers import SpeechT5HifiGan
import soundfile as sf
Expand Down Expand Up @@ -59,6 +58,7 @@ def __init__(self, output_audio_path="./response.wav", voice="default", stream_m
self.stream_mode = stream_mode
self.spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
try:
from speechbrain.pretrained import EncoderClassifier
self.speaker_model = EncoderClassifier.from_hparams(
source=self.spk_model_name,
run_opts={"device": "cpu"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
class EnglishNormalizer:
def __init__(self):
self.correct_dict = {
"A": "Eigh",
"A": "eigh",
"B": "bee",
"C": "cee",
"D": "dee",
"E": "yee",
"F": "ef",
"G": "jee",
"H": "aitch",
"I": "I",
"I": "eye",
"J": "jay",
"K": "kay",
"L": "el",
Expand All @@ -58,8 +58,7 @@ def __init__(self):
def correct_abbreviation(self, text):
# TODO mixed abbreviation or proper noun like i7, ffmpeg, BTW should be supported

# words = text.split() # CVPR-15 will be upper but 1 and 5 will be split into two numbers
words = re.split(' |_|/', text)
words = re.split(r' |_|/|\*|\#', text) # ignore the characters that not break sentence
results = []
for idx, word in enumerate(words):
if word.startswith("-"): # bypass negative number
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_correct_conjunctions(self):
text = "CVPR-15 ICML-21 PM2.5"
text = self.normalizer.correct_abbreviation(text)
result = self.normalizer.correct_number(text)
self.assertEqual(result, "cee vee pea ar fifteen I cee em el twenty-one pea em two point five.")
self.assertEqual(result, "cee vee pea ar fifteen eye cee em el twenty-one pea em two point five.")

# Script entry point: run this module's unittest test cases when the file is
# executed directly rather than imported.
if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def test_tts_long_text(self):
output_audio_path = os.path.join(os.getcwd(), "tmp_audio/2.wav")
set_seed(555)
output_audio_path = self.tts.text2speech(text, output_audio_path, voice="default", do_batch_tts=True, batch_length=120)
result = self.asr.audio2text(output_audio_path)
self.assertTrue(os.path.exists(output_audio_path))
self.assertEqual("intel extension for transformers is an innovative toolkit to accelerate transformer based " + \
"models on intel platforms in particular effective on 4th intel xeon scalable processor " + \
"sapphire rapids codenamed sapphire rapids", result)

def test_create_speaker_embedding(self):
driven_audio_path = \
Expand All @@ -117,5 +121,15 @@ def test_tts_remove_noise(self):
result = self.asr.audio2text(output_audio_path)
self.assertEqual(text.lower(), result.lower())

def test_tts_messy_input(self):
    """Synthesize a sentence followed by a long run of asterisks.

    The input is one readable sentence, a newline, and then 245 ``*``
    characters separated by spaces. The test checks that the audio file is
    written and that transcribing it yields only the readable sentence.
    """
    sentence = "Please refer to the following responses to this inquiry:\n"
    asterisk_tail = 244 * "* " + "*"
    messy_text = sentence + asterisk_tail
    requested_path = os.path.join(os.getcwd(), "tmp_audio/6.wav")
    # Fix the seed before synthesis, as the original test does.
    set_seed(555)
    written_path = self.tts_noise_reducer.text2speech(messy_text, requested_path, voice="default")
    self.assertTrue(os.path.exists(written_path))
    # Transcribe the generated audio and compare case-insensitively.
    transcript = self.asr.audio2text(written_path)
    self.assertEqual("please refer to the following responses to this inquiry", transcript.lower())

# Script entry point: run this module's unittest test cases when the file is
# executed directly rather than imported.
if __name__ == "__main__":
unittest.main()