Skip to content

Commit

Permalink
Streaming ASR timestamps (#146)
Browse files Browse the repository at this point in the history
* WIP: Add timestamps for streaming ASR

* Add timestamps for streaming modified beam search

* Add timestamps for streaming fast_beam_search

* Add timestamp to other recipe

* Add timestamp to other recipes

* Fix cpp style

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
  • Loading branch information
ezerhouni and csukuangfj authored Oct 4, 2022
1 parent b163428 commit 0649c8a
Show file tree
Hide file tree
Showing 26 changed files with 666 additions and 103 deletions.
89 changes: 77 additions & 12 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def process(

processed_lens = (num_processed_frames >> 2) + encoder_out_lens
if self.decoding_method == "fast_beam_search_nbest":
next_hyp_list, next_trailing_blank_frames = fast_beam_search_nbest(
res = fast_beam_search_nbest(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -145,10 +145,7 @@ def process(
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search_nbest_LG":
(
next_hyp_list,
next_trailing_blank_frames,
) = fast_beam_search_nbest_LG(
res = fast_beam_search_nbest_LG(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -160,10 +157,7 @@ def process(
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search":
(
next_hyp_list,
next_trailing_blank_frames,
) = fast_beam_search_one_best(
res = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -178,8 +172,13 @@ def process(
next_state_list = unstack_states(next_states)
for i, s in enumerate(stream_list):
s.states = next_state_list[i]
s.hyp = next_hyp_list[i]
s.num_trailing_blank_frames = next_trailing_blank_frames[i]
s.hyp = res.hyps[i]
s.num_trailing_blank_frames = res.num_trailing_blanks[i]

s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)
s.timestamps = res.timestamps[i]
s.tokens = res.tokens[i]

def get_texts(self, stream: Stream) -> str:
"""
Expand All @@ -196,6 +195,22 @@ def get_texts(self, stream: Stream) -> str:

return result

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens decoded so far for ``stream``, as strings.

    Args:
      stream:
        The stream to be processed; its ``tokens`` attribute holds the
        token IDs produced by fast_beam_search decoding.
    Returns:
      A list of token strings — NOT a single string.  (The previous
      ``-> str`` annotation was incorrect.)  IDs are mapped through the
      sentencepiece processor when ``self.sp`` exists (BPE models),
      otherwise through ``self.token_table``.
    """
    # Annotations are quoted so they need not be resolvable at
    # function-definition time.
    token_ids = stream.tokens
    if hasattr(self, "sp"):
        return [self.sp.id_to_piece(i) for i in token_ids]
    return [self.token_table[i] for i in token_ids]


class GreedySearch:
def __init__(
Expand Down Expand Up @@ -239,6 +254,7 @@ def init_stream(self, stream: Stream):
stream.hyp = [
self.beam_search_params["blank_id"]
] * self.beam_search_params["context_size"]
stream.timestamps = [] # containing frame numbers after subsampling

@torch.no_grad()
def process(
Expand Down Expand Up @@ -266,6 +282,7 @@ def process(
decoder_out_list, hyp_list = [], []
processed_frames_list = []
num_trailing_blank_frames_list = []
frame_offset_list, timestamps_list = [], []

for s in stream_list:
decoder_out_list.append(s.decoder_out)
Expand All @@ -280,6 +297,8 @@ def process(
feature_list.append(b)

num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)
frame_offset_list.append(s.segment_frame_offset)
timestamps_list.append(s.timestamps)

features = torch.stack(feature_list, dim=0).to(device)
states = stack_states(state_list)
Expand Down Expand Up @@ -315,12 +334,15 @@ def process(
next_decoder_out,
next_hyp_list,
next_trailing_blank_frames,
next_timestamps,
) = streaming_greedy_search(
model=model,
encoder_out=encoder_out,
decoder_out=decoder_out,
hyps=hyp_list,
num_trailing_blank_frames=num_trailing_blank_frames_list,
frame_offset=frame_offset_list,
timestamps=timestamps_list,
)

next_decoder_out_list = next_decoder_out.split(1)
Expand All @@ -331,6 +353,9 @@ def process(
s.decoder_out = next_decoder_out_list[i]
s.hyp = next_hyp_list[i]
s.num_trailing_blank_frames = next_trailing_blank_frames[i]
s.timestamps = next_timestamps[i]
s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)

def get_texts(self, stream: Stream) -> str:
"""
Expand All @@ -343,6 +368,21 @@ def get_texts(self, stream: Stream) -> str:
stream.hyp[self.beam_search_params["context_size"] :]
)

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens decoded so far for ``stream``, as strings.

    The first ``context_size`` entries of ``stream.hyp`` are blank-ID
    padding used to prime the decoder, so they are stripped before the
    IDs are converted.

    Args:
      stream:
        The stream to be processed.
    Returns:
      A list of token strings — NOT a single string.  (The previous
      ``-> str`` annotation was incorrect.)  IDs are mapped through the
      sentencepiece processor when ``self.sp`` exists (BPE models),
      otherwise through ``self.token_table``.
    """
    # Annotations are quoted so they need not be resolvable at
    # function-definition time.
    hyp = stream.hyp[self.beam_search_params["context_size"] :]
    if hasattr(self, "sp"):
        return [self.sp.id_to_piece(i) for i in hyp]
    return [self.token_table[i] for i in hyp]


class ModifiedBeamSearch:
def __init__(self, beam_search_params: dict):
Expand Down Expand Up @@ -382,7 +422,7 @@ def process(
state_list, feature_list = [], []
hyp_list = []
processed_frames_list = []
num_trailing_blank_frames_list = []
num_trailing_blank_frames_list, frame_offset_list = [], []

for s in stream_list:
hyp_list.append(s.hyps)
Expand All @@ -396,6 +436,7 @@ def process(
feature_list.append(b)

num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)
frame_offset_list.append(s.segment_frame_offset)

features = torch.stack(feature_list, dim=0).to(device)
states = stack_states(state_list)
Expand Down Expand Up @@ -429,6 +470,7 @@ def process(
model=model,
encoder_out=encoder_out,
hyps=hyp_list,
frame_offset=frame_offset_list,
num_active_paths=self.beam_search_params["num_active_paths"],
)

Expand All @@ -437,7 +479,13 @@ def process(
s.states = next_state_list[i]
s.hyps = next_hyps_list[i]
trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
best_hyp = s.hyps.get_most_probable(True)

trailing_blanks = best_hyp.num_trailing_blanks
s.timestamps = best_hyp.timestamps
s.num_trailing_blank_frames = trailing_blanks
s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)

def get_texts(self, stream: Stream) -> str:
hyp = stream.hyps.get_most_probable(True).ys[
Expand All @@ -450,3 +498,20 @@ def get_texts(self, stream: Stream) -> str:
result = "".join(result)

return result

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens of the best hypothesis for ``stream``, as strings.

    The most probable hypothesis (with length normalization) is taken
    from ``stream.hyps``; its first ``context_size`` entries are
    blank-ID padding used to prime the decoder and are stripped.

    Args:
      stream:
        The stream to be processed.
    Returns:
      A list of token strings — NOT a single string.  (The previous
      ``-> str`` annotation was incorrect.)  IDs are mapped through the
      sentencepiece processor when ``self.sp`` exists (BPE models),
      otherwise through ``self.token_table``.
    """
    # Annotations are quoted so they need not be resolvable at
    # function-definition time.
    hyp = stream.hyps.get_most_probable(True).ys[
        self.beam_search_params["context_size"] :
    ]
    if hasattr(self, "sp"):
        return [self.sp.id_to_piece(i) for i in hyp]
    return [self.token_table[i] for i in hyp]
6 changes: 6 additions & 0 deletions sherpa/bin/conv_emformer_transducer_stateless2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ def __init__(
# whenever an endpoint is detected, it is incremented
self.segment = 0

# Number of frames decoded so far (after subsampling)
self.frame_offset = 0 # never reset

# frame offset within the current segment after subsampling
self.segment_frame_offset = 0 # reset on endpointing

def accept_waveform(
self,
sampling_rate: float,
Expand Down
24 changes: 24 additions & 0 deletions sherpa/bin/conv_emformer_transducer_stateless2/streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
RnntConvEmformerModel,
add_beam_search_arguments,
add_online_endpoint_arguments,
convert_timestamp,
)


Expand Down Expand Up @@ -264,6 +265,7 @@ def __init__(
beam_search_params["blank_id"] = self.blank_id

decoding_method = beam_search_params["decoding_method"]
self.decoding_method = decoding_method
if decoding_method.startswith("fast_beam_search"):
self.beam_search = FastBeamSearch(
beam_search_params=beam_search_params,
Expand Down Expand Up @@ -464,15 +466,27 @@ async def handle_connection_impl(
while len(stream.features) > self.chunk_length_pad:
await self.compute_and_decode(stream)
hyp = self.beam_search.get_texts(stream)
tokens = self.beam_search.get_tokens(stream)

segment = stream.segment
timestamps = convert_timestamp(
frames=stream.timestamps,
subsampling_factor=stream.subsampling_factor,
)

frame_offset = stream.frame_offset * stream.subsampling_factor

is_final = stream.endpoint_detected(self.online_endpoint_config)
if is_final:
self.beam_search.init_stream(stream)

message = {
"method": self.decoding_method,
"segment": segment,
"frame_offset": frame_offset,
"text": hyp,
"tokens": tokens,
"timestamps": timestamps,
"final": is_final,
}

Expand All @@ -489,10 +503,20 @@ async def handle_connection_impl(
stream.features = []

hyp = self.beam_search.get_texts(stream)
tokens = self.beam_search.get_tokens(stream)
frame_offset = stream.frame_offset * stream.subsampling_factor
timestamps = convert_timestamp(
frames=stream.timestamps,
subsampling_factor=stream.subsampling_factor,
)

message = {
"method": self.decoding_method,
"segment": stream.segment,
"frame_offset": frame_offset,
"text": hyp,
"tokens": tokens,
"timestamps": timestamps,
"final": True, # end of connection, always set final to True
}

Expand Down
Loading

0 comments on commit 0649c8a

Please sign in to comment.