Skip to content

Commit 5d4d7f5

Browse files
committed
in grads more WIP
1 parent 63bd426 commit 5d4d7f5

File tree

1 file changed

+324
-0
lines changed

1 file changed

+324
-0
lines changed
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
from __future__ import annotations
2+
from typing import Optional, Any, Dict, List, Tuple
3+
from sisyphus import Job, Task, tk
4+
from i6_experiments.users.zeyer.external_models.huggingface import get_content_dir_from_hub_cache_dir
5+
from i6_experiments.users.zeyer.sis_tools.instanciate_delayed import instanciate_delayed_copy
6+
7+
8+
class ChunkSegmentationFromModelLongFormJob(Job):
    """
    Long-form variant.

    For every sequence of the dataset, the audio is cut into fixed-size (optionally overlapping) chunks,
    and a dynamic-programming search over the (chunk, word) grid assigns each word to one chunk,
    using the model scores for "emit next word" vs. "exit to next chunk".

    Output (``out_hdf``), one entry per sequence, with seq tag ``seq-<seq_idx>``:

    * main data: ``[num_chunks, 2]``, the ``(word_start, word_end)`` index range per chunk
      (end exclusive), or ``(-1, -1)`` for a chunk without words.
    * extra ``audio_chunk_start_end``: ``[num_chunks, 2]``, the ``(start, end)`` audio sample range per chunk.

    NOTE(review): :func:`run` is work-in-progress. The scoring part still references names
    which are not defined in this file (see the NOTE(review) comments there).
    """

    def __init__(
        self,
        *,
        dataset_dir: tk.Path,
        dataset_key: str,
        returnn_root: Optional[tk.Path] = None,
        model_config: Dict[str, Any],
        chunk_size_secs: float = 30.0,
        chunk_overlap_secs: float = 5.0,
        empty_exit_penalty: float = -5.0,
        word_start_heuristic: bool = True,
        dump_wav_first_n_seqs: int = 0,
    ):
        """
        :param dataset_dir: hub cache dir, e.g. via DownloadHuggingFaceRepoJobV2. for load_dataset
        :param dataset_key: e.g. "train", "test", whatever the dataset provides
        :param returnn_root: passed to ``i6_core.util.get_returnn_root``. the resulting path is put on ``sys.path``
        :param model_config: kwargs for ``make_model`` (in addition to ``device``).
            Delayed objects in it are instantiated (``instanciate_delayed_copy``) before use.
        :param chunk_size_secs: chunk size in seconds. must be positive
        :param chunk_overlap_secs: overlap of consecutive chunks in seconds.
            must be non-negative and smaller than ``chunk_size_secs``, otherwise the chunking would never advance
        :param empty_exit_penalty: added to the exit log prob of a chunk when no word was consumed in it
            (the prob for empty chunks is often overestimated)
        :param word_start_heuristic: if enabled, every chunk after the first only considers words starting from
            the word index with the best exit score in the previous chunk. otherwise always starts from word 0
        :param dump_wav_first_n_seqs: for debugging. for the first n seqs, dump audio and words per chunk
            into the working dir. not part of the job hash
        :raises ValueError: on invalid ``chunk_size_secs`` / ``chunk_overlap_secs``
        """
        # Fail early (at graph construction) instead of hanging later in the GPU job.
        if chunk_size_secs <= 0:
            raise ValueError(f"chunk_size_secs must be positive, got {chunk_size_secs}")
        if not 0 <= chunk_overlap_secs < chunk_size_secs:
            raise ValueError(
                f"chunk_overlap_secs must be in [0, chunk_size_secs),"
                f" got {chunk_overlap_secs=} with {chunk_size_secs=}"
            )

        super().__init__()

        self.dataset_dir = dataset_dir
        self.dataset_key = dataset_key
        self.returnn_root = returnn_root
        self.model_config = model_config

        self.chunk_size_secs = chunk_size_secs
        self.chunk_overlap_secs = chunk_overlap_secs
        self.empty_exit_penalty = empty_exit_penalty
        self.word_start_heuristic = word_start_heuristic
        self.dump_wav_first_n_seqs = dump_wav_first_n_seqs

        self.rqmt = {"time": 40, "cpu": 2, "gpu": 1, "mem": 125}

        self.out_hdf = self.output_path("out.hdf")

    @classmethod
    def hash(cls, parsed_args):
        """
        :param parsed_args: the args of ``__init__``
        :return: job hash. ``dump_wav_first_n_seqs`` is excluded (debugging only)
        """
        # Work on a copy, to not modify the dict of the caller.
        hashed_args = dict(parsed_args)
        hashed_args.pop("dump_wav_first_n_seqs", None)
        return super().hash(hashed_args)

    def tasks(self):
        """
        :return: yields the single task ``run``, using ``self.rqmt``
        """
        yield Task("run", rqmt=self.rqmt)

    def run(self):
        """
        Load the model and the dataset, calculate the chunk-wise word segmentation per sequence,
        and write it to ``self.out_hdf``.

        Expects these fields per dataset entry: ``data["audio"]["array"]``, ``data["audio"]["sampling_rate"]``,
        ``data["word_detail"]["utterance"]`` (list of words).

        :raises ValueError: if the chunk overlap in samples is not smaller than the chunk size in samples
        """
        import os
        import sys
        import time
        import math
        from dataclasses import dataclass

        # NOTE(review): presumably invalid on purpose, such that nothing silently uses the default HF hub cache
        # -- confirm.
        os.environ["HF_HUB_CACHE"] = "/<on_purpose_invalid_hf_hub_cache_dir>"

        import i6_experiments

        recipe_dir = os.path.dirname(os.path.dirname(i6_experiments.__file__))
        sys.path.insert(0, recipe_dir)

        import i6_core.util as util

        returnn_root = util.get_returnn_root(self.returnn_root)
        sys.path.insert(0, returnn_root.get_path())

        print("Import Torch, Numpy...")
        start_time = time.time()

        import numpy as np
        import torch

        print(f"({time.time() - start_time} secs)")

        from returnn.util import better_exchook
        from returnn.datasets.hdf import SimpleHDFWriter
        from i6_experiments.users.zeyer.torch.report_dev_memory_stats import report_dev_memory_stats

        # os.environ["DEBUG"] = "1"  # for better_exchook to use debug shell on error
        better_exchook.install()

        try:
            # noinspection PyUnresolvedReferences
            import lovely_tensors

            lovely_tensors.monkey_patch()
        except ImportError:
            pass  # optional, only for nicer tensor repr

        from .models import make_model, ForwardOutput

        # In the (S+1)*C grid (RNN-T style), but we are not filling all S+1 entries per chunk.
        # Defined once here (does not depend on the sequence).
        @dataclass
        class _Node:
            chunk_idx: int  # 0 <= c < C. the chunk we are in.
            word_idx: int  # 0 <= s <= S. we have seen this many words so far, words[:s]
            log_prob: torch.Tensor  # []. log prob of this node
            exit_log_prob: torch.Tensor  # []. log_prob+exit (end_token_id). horizontal transition to next chunk
            word_log_prob: Optional[
                torch.Tensor
            ]  # []. log_prob+word (one or more labels). vertical transition to next word. (None if s==S)
            backpointer: Optional[_Node]  # prev chunk, or prev word

        device_str = "cuda"
        dev = torch.device(device_str)

        model_config = instanciate_delayed_copy(self.model_config)
        model = make_model(**model_config, device=dev)

        for p in model.parameters():
            p.requires_grad = False

        report_dev_memory_stats(dev)

        # Write word start/end ranges per chunk, and the chunk audio sample start/end ranges.
        hdf_writer = SimpleHDFWriter(
            self.out_hdf.get_path(), dim=2, ndim=2, extra_type={"audio_chunk_start_end": (2, 2, "int32")}
        )

        # Make sure that the HDF file is closed also when anything below fails.
        try:
            # Iter over data

            from datasets import load_dataset

            ds = load_dataset(get_content_dir_from_hub_cache_dir(self.dataset_dir))
            print(f"Dataset: {ds}")
            print("Dataset keys:", ds.keys())
            print("Using key:", self.dataset_key)
            print("Num seqs:", len(ds[self.dataset_key]))

            for seq_idx, data in enumerate(ds[self.dataset_key]):
                audio = data["audio"]["array"]
                if not isinstance(audio, np.ndarray):
                    audio = np.array(audio)
                samplerate = data["audio"]["sampling_rate"]
                chunk_size_samples = math.ceil(self.chunk_size_secs * samplerate)
                chunk_overlap_samples = math.ceil(self.chunk_overlap_secs * samplerate)
                if not 0 <= chunk_overlap_samples < chunk_size_samples:
                    # Otherwise the chunk start would not advance below, i.e. endless loop.
                    # (Checked in samples, as the rounding can make both equal.)
                    raise ValueError(
                        f"chunk overlap ({chunk_overlap_samples} samples) must be non-negative and smaller than"
                        f" the chunk size ({chunk_size_samples} samples), {samplerate=}"
                    )
                words: List[str] = data["word_detail"]["utterance"]
                transcription = " ".join(words)
                print(
                    f"* Seq {seq_idx}, {audio.shape=}, {len(audio) / samplerate} secs,"
                    f" {samplerate=}, {transcription!r}"
                )
                # I.e. no word contains a space.
                assert len(transcription.split(" ")) == len(words)

                if seq_idx == 0:
                    print(" data keys:", data.keys())

                # First a loop to determine the coarse chunk-wise segmentation:
                # For fixed chunks (partially overlapping), assign the most likely words.
                # Dyn programming, outer loop over chunks.

                print("* Chunkwise segmenting...")

                chunk_start_end: List[Tuple[int, int]] = []  # in samples
                cur_audio_start = 0  # in samples
                while True:  # while not ended
                    cur_audio_end = cur_audio_start + chunk_size_samples
                    if cur_audio_end > len(audio):
                        cur_audio_end = len(audio)
                    if len(audio) - cur_audio_end <= 128 and self.chunk_overlap_secs == 0:
                        # Skip to end. Avoids potential problems with too short chunks.
                        cur_audio_end = len(audio)
                    assert cur_audio_end > cur_audio_start
                    assert cur_audio_end - cur_audio_start > 1  # require some min len
                    chunk_start_end.append((cur_audio_start, cur_audio_end))
                    if cur_audio_end >= len(audio):
                        break  # only break point here
                    # Strictly larger than the prev start, as overlap < chunk size (checked above).
                    cur_audio_start = cur_audio_end - chunk_overlap_samples
                    assert cur_audio_start >= 0

                array: List[List[_Node]] = []  # [chunk_idx][rel word_idx]

                for cur_chunk_idx, (cur_audio_start, cur_audio_end) in enumerate(chunk_start_end):
                    if cur_chunk_idx == 0 or not self.word_start_heuristic:
                        prev_array_word_idx = 0
                        cur_word_start = 0
                    else:
                        # Heuristic. Look through last chunk, look out for best exit_log_prob
                        prev_array_word_idx = int(
                            torch.stack([node.exit_log_prob for node in array[cur_chunk_idx - 1]]).argmax().item()
                        )
                        cur_word_start = array[cur_chunk_idx - 1][prev_array_word_idx].word_idx
                    cur_word_end = len(words)  # Go to the end. Not so expensive...
                    print(
                        f"** Forwarding chunk {cur_chunk_idx} (out of {len(chunk_start_end)}),"
                        f" {cur_audio_start / samplerate}:{cur_audio_end / samplerate} secs,"
                        f" words {cur_word_start}:{cur_word_end} (out of {len(words)})"
                    )
                    assert cur_word_end > cur_word_start  # need to fix heuristic if this fails...
                    if cur_audio_end >= len(audio):
                        assert cur_word_end == len(words)  # need to overthink approx if this fails...

                    forward_output: ForwardOutput = model(
                        raw_inputs=torch.tensor(audio[cur_audio_start:cur_audio_end]).unsqueeze(0),
                        raw_inputs_sample_rate=samplerate,
                        raw_input_seq_lens=torch.tensor([cur_audio_end - cur_audio_start]),
                        raw_targets=[words[cur_word_start:cur_word_end]],
                        raw_target_seq_lens=torch.tensor([cur_word_end - cur_word_start]),
                        omitted_prev_context=torch.tensor([cur_word_start]),
                    )

                    # Calculate log probs
                    # logits = model.lm_head(last_out[:, dst_text_start - 1 :])  # [B,T-dst_text_start+1,V]
                    # logits = logits.float()
                    # log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # [B,T-dst_text_start,V]

                    # NOTE(review): WIP placeholder. log_probs is indexed below, which fails on Ellipsis.
                    log_probs = ...

                    array.append([])
                    assert len(array) == cur_chunk_idx + 1
                    # NOTE(review): words_start_end, dst_text_start, dst_text_end, input_ids, end_token_id
                    #   are not defined anywhere in this file (NameError here).
                    #   Presumably this whole scoring is to be replaced by model.score(...) -- confirm.
                    for w, (t0, t1) in enumerate(words_start_end + [(dst_text_end, dst_text_end + 1)]):
                        # NOTE(review): score is not used yet.
                        score = model.score(forward_output=forward_output, raw_target_frame_index=w)
                        word_idx = cur_word_start + w
                        if word_idx < cur_word_end:
                            word_log_prob = torch.sum(
                                torch.stack(
                                    [log_probs[0, t - dst_text_start][input_ids[0, t]] for t in range(t0, t1)]
                                )
                            )  # []
                        else:
                            word_log_prob = None
                        exit_log_prob = log_probs[0, t0 - dst_text_start][end_token_id]  # []
                        if w == 0:
                            # Add some penalty. For empty chunks, the prob is often overestimated.
                            # Not in-place: exit_log_prob is indexed out of log_probs, presumably a view on it,
                            # and log_probs itself should not be modified.
                            exit_log_prob = exit_log_prob + self.empty_exit_penalty
                        prev_node_left, prev_node_below = None, None
                        if w > 0:
                            prev_node_below = array[cur_chunk_idx][-1]
                            assert prev_node_below.word_idx == word_idx - 1
                        if cur_chunk_idx > 0 and prev_array_word_idx + w < len(array[cur_chunk_idx - 1]):
                            prev_node_left = array[cur_chunk_idx - 1][prev_array_word_idx + w]
                            assert prev_node_left.word_idx == word_idx
                        # Select the best predecessor: same chunk prev word (below), or prev chunk same word (left).
                        if prev_node_below is not None and prev_node_left is not None:
                            if prev_node_below.word_log_prob >= prev_node_left.exit_log_prob:
                                prev_node = prev_node_below
                                log_prob = prev_node_below.word_log_prob
                            else:
                                prev_node = prev_node_left
                                log_prob = prev_node_left.exit_log_prob
                        elif prev_node_below is not None:
                            prev_node = prev_node_below
                            log_prob = prev_node_below.word_log_prob
                        elif prev_node_left is not None:
                            prev_node = prev_node_left
                            log_prob = prev_node_left.exit_log_prob
                        else:
                            # Only the very first node of the grid has no predecessor.
                            assert cur_chunk_idx == word_idx == 0
                            prev_node = None
                            log_prob = torch.zeros(())
                        array[cur_chunk_idx].append(
                            _Node(
                                chunk_idx=cur_chunk_idx,
                                word_idx=word_idx,
                                log_prob=log_prob,
                                backpointer=prev_node,
                                word_log_prob=(log_prob + word_log_prob) if word_idx < cur_word_end else None,
                                exit_log_prob=log_prob + exit_log_prob,
                            )
                        )
                    assert (
                        len(array[cur_chunk_idx]) == cur_word_end - cur_word_start + 1
                        and array[cur_chunk_idx][0].word_idx == cur_word_start
                        and array[cur_chunk_idx][-1].word_idx == cur_word_end
                    )

                    del forward_output, log_probs  # not needed anymore now

                # Backtrack
                nodes_alignment: List[_Node] = []
                node = array[-1][-1]
                assert node.word_idx == len(words)  # has seen all words
                while node is not None:
                    nodes_alignment.append(node)
                    node = node.backpointer
                nodes_alignment.reverse()

                # Collect words per chunk
                words_per_chunks: List[List[int]] = [[] for _ in range(len(chunk_start_end))]
                words_covered = 0
                for node in nodes_alignment[1:]:
                    if node.backpointer.chunk_idx == node.chunk_idx:
                        # Vertical transition: one more word in the same chunk.
                        assert node.word_idx == node.backpointer.word_idx + 1
                        words_per_chunks[node.chunk_idx].append(node.word_idx - 1)
                        assert words_covered == node.word_idx - 1
                        words_covered += 1
                    else:
                        # Horizontal transition: next chunk, no new word.
                        assert node.chunk_idx == node.backpointer.chunk_idx + 1
                        assert node.word_idx == node.backpointer.word_idx
                assert words_covered == len(words)
                # (start, end) word indices per chunk, end exclusive. (-1, -1) for a chunk without words.
                words_indices_start_end = [(ws[0], ws[-1] + 1) if ws else (-1, -1) for ws in words_per_chunks]
                print(" Words per chunks:", words_indices_start_end)

                assert len(words_indices_start_end) == len(chunk_start_end)
                hdf_writer.insert_batch(
                    np.array(words_indices_start_end)[None],
                    seq_len=[len(chunk_start_end)],
                    seq_tag=[f"seq-{seq_idx}"],
                    extra={"audio_chunk_start_end": np.array(chunk_start_end)[None]},
                )

                if seq_idx < self.dump_wav_first_n_seqs:
                    for cur_chunk_idx, ((cur_audio_start, cur_audio_end), ws) in enumerate(
                        zip(chunk_start_end, words_per_chunks)
                    ):
                        # NOTE(review): write_wave_file is neither defined nor imported in this file.
                        write_wave_file(
                            f"seq{seq_idx}-chunk{cur_chunk_idx}.wav",
                            samples=audio[cur_audio_start:cur_audio_end],
                            sr=samplerate,
                        )
                        with open(f"seq{seq_idx}-chunk{cur_chunk_idx}.txt", "w", encoding="utf-8") as f:
                            f.write(" ".join(words[w] for w in ws))
        finally:
            hdf_writer.close()

        # better_exchook.debug_shell(user_ns=locals(), user_global_ns=locals())

0 commit comments

Comments
 (0)