add auto_rerank part (#393)
* add auto_rerank part

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* switch to UTF-8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo-Mika and pre-commit-ci[bot] authored Jul 18, 2024
1 parent 04b6c10 commit dc250ab
Showing 3 changed files with 241 additions and 8 deletions.
35 changes: 35 additions & 0 deletions tools/api.py
@@ -32,6 +32,7 @@

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
@@ -293,6 +294,39 @@ def inference(req: InvokeRequest):
yield fake_audios


def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
if not use_auto_rerank:
        # If auto_rerank is disabled, fall back to the original inference function
return inference(req)

zh_model, en_model = load_model()
max_attempts = 5
best_wer = float("inf")
best_audio = None

for attempt in range(max_attempts):
        # Call the original inference function and take its first yielded chunk
audio_generator = inference(req)
fake_audios = next(audio_generator)

asr_result = batch_asr(
zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
)[0]
wer = calculate_wer(req.text, asr_result["text"])

if wer <= 0.1 and not asr_result["huge_gap"]:
return fake_audios

if wer < best_wer:
best_wer = wer
best_audio = fake_audios

if attempt == max_attempts - 1:
break

return best_audio


async def inference_async(req: InvokeRequest):
for chunk in inference(req):
yield chunk
@@ -377,6 +411,7 @@ def parse_args():
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--use-auto-rerank", type=bool, default=True)

return parser.parse_args()

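In short, the API-side loop synthesizes, transcribes the candidate with FunASR, and accepts the first take whose WER against the request text is at most 0.1 with no huge silence gap, otherwise returning the best of up to five attempts. A minimal sketch of how a caller might exercise the new path (the InvokeRequest constructor argument shown here is an assumption; see tools/api.py for the real schema):

from tools.api import InvokeRequest, auto_rerank_inference

# Hypothetical minimal request; the real InvokeRequest carries more fields
# (reference audio, sampling parameters, and so on).
req = InvokeRequest(text="Hello, world.")
audio = auto_rerank_inference(req, use_auto_rerank=True)  # best candidate audio

Launching the server with --use-auto-rerank false skips this loop and calls inference directly.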
126 changes: 126 additions & 0 deletions tools/auto_rerank.py
@@ -0,0 +1,126 @@
import time
from threading import Lock

import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from funasr.models.seaco_paraformer.model import SeacoParaformer

# Monkey patching to disable hotwords
SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
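# (Without the patch, SeacoParaformer's generate path tries to assemble a
# hotword biasing list; returning None makes it behave as plain ASR.)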


def load_model(*, device="cuda"):
zh_model = AutoModel(
model="paraformer-zh",
device=device,
disable_pbar=True,
)
en_model = AutoModel(
model="paraformer-en",
device=device,
disable_pbar=True,
)

return zh_model, en_model


@torch.no_grad()
def batch_asr_internal(model, audios, sr):
resampled_audios = []
for audio in audios:
        # Convert NumPy arrays to PyTorch tensors
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()

        # Ensure the audio is one-dimensional
        if audio.dim() > 1:
            audio = audio.squeeze()

        # Resample to the 16 kHz input rate the paraformer models expect
        audio = torchaudio.functional.resample(audio, sr, 16000)
assert audio.dim() == 1
resampled_audios.append(audio)

res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))

results = []
for r, audio in zip(res, audios):
text = r["text"]
duration = len(audio) / sr * 1000
huge_gap = False

if "timestamp" in r and len(r["timestamp"]) > 2:
for timestamp_a, timestamp_b in zip(
r["timestamp"][:-1], r["timestamp"][1:]
):
# If there is a gap of more than 5 seconds, we consider it as a huge gap
if timestamp_b[0] - timestamp_a[1] > 5000:
huge_gap = True
break

# Doesn't make sense to have a huge gap at the end
if duration - r["timestamp"][-1][1] > 3000:
huge_gap = True

results.append(
{
"text": text,
"duration": duration,
"huge_gap": huge_gap,
}
)

return results


global_lock = Lock()


def batch_asr(model, audios, sr):
    # Serialize access to the shared ASR models across threads
    with global_lock:
        return batch_asr_internal(model, audios, sr)


def is_chinese(text):
    # Stub: always route text to the Chinese model for now
    return True


def calculate_wer(text1, text2):
words1 = text1.split()
words2 = text2.split()

    # Compute the word-level edit (Levenshtein) distance
m, n = len(words1), len(words2)
dp = [[0] * (n + 1) for _ in range(m + 1)]

for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j

for i in range(1, m + 1):
for j in range(1, n + 1):
if words1[i - 1] == words2[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1

    # WER = edit count divided by the reference word count
edits = dp[m][n]
wer = edits / len(words1)

return wer


if __name__ == "__main__":
zh_model, en_model = load_model()
audios = [
torchaudio.load("lengyue.wav")[0][0],
torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
]
print(batch_asr(zh_model, audios, 44100))

start_time = time.time()
for _ in range(10):
batch_asr(zh_model, audios, 44100)
print("Time taken:", time.time() - start_time)
88 changes: 80 additions & 8 deletions tools/webui.py
@@ -21,6 +21,7 @@
from fish_speech.i18n import i18n
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from tools.api import decode_vq_tokens, encode_reference
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
@@ -162,7 +163,81 @@ def inference(
gc.collect()


-inference_stream = partial(inference, streaming=True)
def inference_with_auto_rerank(
text,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
streaming=False,
use_auto_rerank=True,
):
if not use_auto_rerank:
return inference(
text,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
streaming,
)

zh_model, en_model = load_model()
max_attempts = 2
best_wer = float("inf")
best_audio = None
best_sample_rate = None

for attempt in range(max_attempts):
audio_generator = inference(
text,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
streaming=False,
)

        # Drain the generator and keep its final yielded item
        result = None
        for result in audio_generator:
            pass
        _, (sample_rate, audio), message = result

if audio is None:
return None, None, message

asr_result = batch_asr(
zh_model if is_chinese(text) else en_model, [audio], sample_rate
)[0]
wer = calculate_wer(text, asr_result["text"])

if wer <= 0.3 and not asr_result["huge_gap"]:
return None, (sample_rate, audio), None

if wer < best_wer:
best_wer = wer
best_audio = audio
best_sample_rate = sample_rate

if attempt == max_attempts - 1:
break

return None, (best_sample_rate, best_audio), None


# Note: with use_auto_rerank left at its default of True, the rerank path above
# always synthesizes with streaming=False, so this partial only actually streams
# when auto-rerank is disabled.
inference_stream = partial(inference_with_auto_rerank, streaming=True)

n_audios = 4

@@ -186,7 +261,7 @@ def inference_wrapper(
errors = []

for _ in range(batch_infer_num):
-        items = inference(
+        result = inference_with_auto_rerank(
text,
enable_reference_audio,
reference_audio,
@@ -198,16 +273,13 @@
temperature,
)

-        try:
-            item = next(items)
-        except StopIteration:
-            print("No more audio data available.")
+        _, audio_data, error_message = result

audios.append(
-            gr.Audio(value=item[1] if (item and item[1]) else None, visible=True),
+            gr.Audio(value=audio_data if audio_data else None, visible=True),
)
errors.append(
-            gr.HTML(value=item[2] if (item and item[2]) else None, visible=True),
+            gr.HTML(value=error_message if error_message else None, visible=True),
)

for _ in range(batch_infer_num, n_audios):
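One design note: the two call sites budget differently. tools/api.py retries up to 5 times against a strict WER threshold of 0.1, while the web UI retries at most twice against a looser 0.3, presumably trading retry quality for interactive latency. The shared accept rule, factored out here as a sketch (not code from this commit):

def should_accept(wer: float, huge_gap: bool, threshold: float) -> bool:
    # Accept a candidate that transcribes close enough to the prompt text
    # and whose ASR timestamps show no suspicious stretch of silence.
    return wer <= threshold and not huge_gap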
