Skip to content

Commit c52f53a

Browse files
authored
Update diarize.py
Fix some bugs and add some parameters
1 parent 7d4a611 commit c52f53a

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

scripts/diarize.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,37 @@
99
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization",
1010
use_auth_token=AUTH_TOKEN)
1111

12-
13-
def diarize_audio(audio_path, out_dir, num_speakers=2, keep_turn=True):
12+
def diarize_audio(audio_path, out_dir=None, num_speakers=None, keep_turn=False, min_sec=0.5, max_sec=None):
1413
sr, audio = wavfile.read(audio_path)
1514
diarization = pipeline(audio_path, num_speakers=num_speakers)
16-
17-
out_dir = os.path.splitext(audio_path)[0]
18-
os.makedirs(out_dir, exist_ok=True)
19-
15+
2016
start_frames, end_frames = None, None
2117
last_spk = None
2218
i = 0
2319
for turn, _, speaker in diarization.itertracks(yield_label=True):
2420
spk = speaker
21+
if out_dir is None:
22+
out_dir = spk
23+
os.makedirs(out_dir, exist_ok=True)
24+
2525
print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker: {spk}")
2626

2727
if keep_turn:
28-
28+
if not start_frames:
29+
start_frames = int(turn.start)
2930
if not last_spk:
3031
last_spk = spk
3132
if spk == last_spk:
3233
end_frames = int(sr*turn.end)
3334
else:
3435
i+=1
36+
if min_sec is not None and (end_frames - start_frames)/sr < min_sec:
37+
print(f"skipping {turn.start:.1f}s stop={turn.end:.1f} because it is too short")
38+
continue
39+
if max_sec is not None and (end_frames - start_frames)/sr > max_sec:
40+
print(f"skipping {turn.start:.1f}s stop={turn.end:.1f} because it is too long")
41+
continue
42+
3543
wavfile.write(os.path.join(out_dir, f"{i:04}-{last_spk}.wav"), sr, audio[start_frames:end_frames])
3644

3745
last_spk = spk
@@ -40,14 +48,15 @@ def diarize_audio(audio_path, out_dir, num_speakers=2, keep_turn=True):
4048
else:
4149
wavfile.write(os.path.join(out_dir, f"{i:04}-{spk}.wav"), sr, audio[int(sr*turn.start):int(sr*turn.end)])
4250
i+=1
43-
51+
4452

4553
if __name__ == "__main__":
4654
import argparse
47-
parser = argparse.ArgumentParser(help="Diarize audio file")
55+
parser = argparse.ArgumentParser("Diarize audio file")
4856
parser.add_argument("audio_path", type=str, help="Path to audio file")
49-
parser.add_argument("out_dir", type=str, help="Path to output directory")
57+
parser.add_argument("--min-sec", type=float, default=0.5)
58+
parser.add_argument("--max-sec", type=float, default=None)
5059
parser.add_argument("--num_speakers", type=int, default=2, help="Number of speakers")
5160
args = parser.parse_args()
5261

53-
diarize_audio(args.audio_path, args.out_dir, args.num_speakers, keep_turn=True)
62+
diarize_audio(args.audio_path, args.out_dir, args.num_speakers, min_sec=args.min_sec, max_sec=args.max_sec, keep_turn=True)

0 commit comments

Comments
 (0)