from typing import List

import k2
import torch
from stream import Stream, stack_states, unstack_states

from sherpa import fast_beam_search_one_best, streaming_greedy_search


class FastBeamSearch:
    def __init__(
        self,
        vocab_size: int,
        context_size: int,
        beam: int,
        max_states: int,
        max_contexts: int,
        device: torch.device,
    ):
        """
        Args:
          vocab_size:
            Vocabulary size of the BPE model.
          context_size:
            Context size of the RNN-T decoder model.
          beam:
            The beam for fast_beam_search decoding.
          max_states:
            The max_states for fast_beam_search decoding.
          max_contexts:
            The max_contexts for fast_beam_search decoding.
          device:
            Device on which the computation will occur.
        """
        self.rnnt_decoding_config = k2.RnntDecodingConfig(
            vocab_size=vocab_size,
            decoder_history_len=context_size,
            beam=beam,
            max_states=max_states,
            max_contexts=max_contexts,
        )
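        # `k2.trivial_graph(vocab_size - 1, device)` below builds an FSA
        # with a self-loop for every symbol from 1 to vocab_size - 1,
        # i.e. a decoding graph that accepts any token sequence without
        # imposing an external language model.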
        self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
        self.device = device
        self.context_size = context_size

    def init_stream(self, stream: Stream):
        """
        Attach the attributes needed by fast_beam_search decoding to the
        given stream.
        """
        stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph)
        stream.hyp = []

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and decode with the
        fast_beam_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in place:
            the attributes `states` and `hyp` of each stream are updated.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        processed_frames_list, rnnt_decoding_streams_list = [], []

        rnnt_decoding_config = self.rnnt_decoding_config
        for s in stream_list:
            rnnt_decoding_streams_list.append(s.rnnt_decoding_stream)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            # s.features is a list of feature tensors; read one chunk's
            # worth (plus right padding) but advance by only chunk_length.
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

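        # At this point every stream has contributed one padded chunk.
        # With hypothetical numbers chunk_length = 32 and
        # chunk_length_pad = 40, each iteration reads 40 feature entries
        # but consumes only 32, so consecutive chunks overlap by 8
        # entries of right-context lookahead.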
        features = torch.stack(feature_list, dim=0).to(device)

        states = stack_states(state_list)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

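        # The `>> 2` below converts processed feature frames to encoder
        # frames: the encoder subsamples by a factor of 4, so, e.g.,
        # 32 processed feature frames correspond to 8 encoder frames.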
        processed_lens = (num_processed_frames >> 2) + encoder_out_lens
        next_hyp_list = fast_beam_search_one_best(
            model=model,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            rnnt_decoding_config=rnnt_decoding_config,
            rnnt_decoding_streams_list=rnnt_decoding_streams_list,
        )

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream) -> str:
        """
        Return the text decoded so far.
        Args:
          stream:
            The stream to read the hypothesis from.
        """
        # NOTE: `self.sp` (a sentencepiece.SentencePieceProcessor) is
        # assumed to be attached to this object by the caller; it is not
        # set in __init__ above.
        return self.sp.decode(stream.hyp)


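# A hedged usage sketch (`server` and `stream` are placeholders, not
# defined in this file): a streaming server typically drives a search
# object like this.
#
#   search = FastBeamSearch(
#       vocab_size=500, context_size=2, beam=4,
#       max_states=32, max_contexts=8, device=torch.device("cpu"),
#   )
#   search.init_stream(stream)       # attach per-stream decoding state
#   while stream.features:           # as feature chunks arrive
#       search.process(server, [stream])
#   text = search.get_texts(stream)  # requires search.sp to be set

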
class GreedySearch:
    def __init__(self, model: "RnntConvEmformerModel", device: torch.device):
        """
        Args:
          model:
            The RNN-T model.
          device:
            Device on which the computation will occur.
        """

        self.blank_id = model.blank_id
        self.context_size = model.context_size
        self.device = device

        # Prime the decoder with context_size blank tokens and cache its
        # projected output; every new stream starts from this state.
        decoder_input = torch.tensor(
            [[self.blank_id] * self.context_size],
            device=self.device,
            dtype=torch.int64,
        )
        initial_decoder_out = model.decoder_forward(decoder_input)
        self.initial_decoder_out = model.forward_decoder_proj(
            initial_decoder_out.squeeze(1)
        )
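        # For example, with blank_id = 0 and context_size = 2 (assumed
        # values), decoder_input is [[0, 0]] with shape (1, 2), so
        # self.initial_decoder_out has shape (1, decoder_out_dim).
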
    def init_stream(self, stream: Stream):
        """
        Attach the attributes needed by greedy search decoding to the
        given stream.
        """
        stream.decoder_out = self.initial_decoder_out
        stream.hyp = [self.blank_id] * self.context_size

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and decode with the
        greedy_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in place:
            the attributes `states`, `decoder_out`, and `hyp` of each
            stream are updated.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        decoder_out_list, hyp_list = [], []
        processed_frames_list = []

        for s in stream_list:
            decoder_out_list.append(s.decoder_out)
            hyp_list.append(s.hyp)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            # s.features is a list of feature tensors; read one chunk's
            # worth (plus right padding) but advance by only chunk_length.
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)
        states = stack_states(state_list)
        # Each stream's decoder_out has batch size 1; concatenate them
        # into a single batch.
        decoder_out = torch.cat(decoder_out_list, dim=0)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # Note: streaming_greedy_search does not return
        # next_encoder_out_lens because there is no padding in streaming
        # ASR: every stream feeds the same number of input frames, i.e.,
        # server.chunk_length.
        next_decoder_out, next_hyp_list = streaming_greedy_search(
            model=model,
            encoder_out=encoder_out,
            decoder_out=decoder_out,
            hyps=hyp_list,
        )

        # split(1) yields a tuple of per-stream tensors, each with batch
        # size 1, in the same order as stream_list.
        next_decoder_out_list = next_decoder_out.split(1)

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream) -> str:
        """
        Return the text decoded so far.
        Args:
          stream:
            The stream to read the hypothesis from.
        """
        # The first context_size entries of stream.hyp are the blank
        # tokens used to prime the decoder, so strip them before
        # decoding.  As in FastBeamSearch, `self.sp` (a sentencepiece
        # processor) is assumed to be attached by the caller.
        return self.sp.decode(stream.hyp[self.context_size :])
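

# A minimal smoke test: a sketch assuming a k2 build with RNN-T decoding
# support is installed.  All parameter values are illustrative, not taken
# from a trained model.
if __name__ == "__main__":
    fbs = FastBeamSearch(
        vocab_size=8,
        context_size=2,
        beam=4,
        max_states=32,
        max_contexts=8,
        device=torch.device("cpu"),
    )
    # Printing the Fsa shows its arcs: one self-loop per non-blank symbol.
    print(fbs.decoding_graph)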