Skip to content

输出词级别时间戳时(output_timestamp=True),无法批量推理 #207

@pyssionately

Description

@pyssionately

res = model.generate(
input=files,
cache={},
language="en", # "zh", "en", "yue", "ja", "ko", "nospeech"
use_itn=False,
batch_size=4,
output_timestamp=True
)

不添加output_timestamp=True 可以正常批量推理,添加了output_timestamp=True会报如下错误:
RuntimeError Traceback (most recent call last)
Cell In[3], line 20
6 files = '''
7 /data/training_share/l00845921/en2zh_wav/en_zh_mp3/74_food_part1_01/74_food_part1_01_5gCGBWiazhY#00-00-07_0.mp3
8 /data/training_share/l00845921/en2zh_wav/en_zh_mp3/74_food_part1_01/74_food_part1_01_6UPkwUvDk9Y#00-00-13_0.mp3
(...)
14 /data/training_share/l00845921/en2zh_wav/en_zh_mp3/74_food_part1_01/74_food_part1_01_3ucLC6aje3w#00-00-15_3.mp3
15 '''.strip().split('\n')
18 model = AutoModel(model=model_dir, trust_remote_code=True, device="cuda:0")
---> 20 res = model.generate(
21 input=files,
22 cache={},
23 language="en", # "zh", "en", "yue", "ja", "ko", "nospeech"
24 use_itn=False,
25 batch_size=4,
26 output_timestamp=True
27 )

File /usr/local/lib/python3.11/site-packages/funasr/auto/auto_model.py:260, in AutoModel.generate(self, input, input_len, **cfg)
258 def generate(self, input, input_len=None, **cfg):
259 if self.vad_model is None:
--> 260 return self.inference(input, input_len=input_len, **cfg)
262 else:
263 return self.inference_with_vad(input, input_len=input_len, **cfg)

File /usr/local/lib/python3.11/site-packages/funasr/auto/auto_model.py:302, in AutoModel.inference(self, input, input_len, model, kwargs, key, **cfg)
300 time1 = time.perf_counter()
301 with torch.no_grad():
--> 302 res = model.inference(**batch, **kwargs)
303 if isinstance(res, (list, tuple)):
304 results = res[0] if len(res) > 0 else [{"text": ""}]

File /data/xyh/SenseVoice/model.py:904, in SenseVoiceSmall.inference(self, data_in, data_lengths, key, tokenizer, frontend, **kwargs)
894 logits_speech[pred==self.blank_id, self.blank_id] = 0
896 align = ctc_forced_align(
897 logits_speech.unsqueeze(0).float(),
898 torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
(...)
901 ignore_id=self.ignore_id,
902 )
--> 904 pred = groupby(align[0, :encoder_out_lens[0]])
905 _start = 0
906 token_id = 0

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Metadata

Metadata

Assignees

No one assigned

    Labels

    question — Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions