Skip to content

Commit deccded

Browse files
committed
Add functioning viz tool
1 parent 3a76a83 commit deccded

File tree

3 files changed

+106
-51
lines changed

3 files changed

+106
-51
lines changed

headbang/consensus.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
algo_names = [
18-
"_", # dummy at index 0 because in this code, beat trackers start at 1: "1,2,3...8"
18+
"_", # dummy at index 0 because in this code, beat trackers start at 1: "1,2,3...8"
1919
"madmom DBNBeatTrackingProcessor",
2020
"madmom BeatDetectionProcessor",
2121
"essentia BeatTrackerMultiFeature",

headbang/viz_tool.py

+104-50
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import argparse
77
import sys
88
import librosa
9+
import pandas as pd
910
import gc
1011
import os
1112
import multiprocessing
13+
from collections import OrderedDict
1214
from moviepy.editor import *
1315
from moviepy.audio.AudioClip import AudioArrayClip
1416
from tempfile import gettempdir
@@ -19,10 +21,18 @@
1921
from madmom.io.audio import write_wave_file
2022

2123

24+
def find_closest(A, target):
25+
# A must be sorted
26+
idx = A.searchsorted(target)
27+
idx = numpy.clip(idx, 1, len(A) - 1)
28+
left = A[idx - 1]
29+
right = A[idx]
30+
idx -= target - left < right - target
31+
return idx
32+
33+
2234
def main():
23-
parser = argparse.ArgumentParser(
24-
description="Vizualize the headbang beat tracker"
25-
)
35+
parser = argparse.ArgumentParser(description="Vizualize the headbang beat tracker")
2636

2737
parser.add_argument("wav_in", type=str, help="wav file to process")
2838
parser.add_argument("mp4_out", type=str, help="mp4 output path")
@@ -41,32 +51,17 @@ def main():
4151
# get the inner multi-beat-tracker list from headbangbeattracker's consensusbeattracker object
4252
individual_tracker_beat_locations = hbt.cbt.beat_results
4353

44-
# blue, yellow, magenta, violet, orange, brown, white
45-
colors = itertools.cycle([(0, 165, 255), (255, 0, 255), (255, 255, 0), (255, 69, 0), (0, 255, 255), (165, 42, 42), (255, 255, 255)])
46-
47-
beat_trackers = {
48-
'headbang': {
49-
'beats': strong_beat_locations,
50-
'color': (255, 0, 0), # red
51-
},
52-
'consensus': {
53-
'beats': hbt.beat_consensus,
54-
'color': (0, 255, 0), # lime green
55-
},
56-
}
57-
for i, algo_name in enumerate(algo_names[1:]):
58-
beat_trackers[algo_name] = {
59-
'beats': individual_tracker_beat_locations[i],
60-
'color': next(colors)
61-
}
62-
63-
beat_trackers['onsets'] = {
64-
'beats': hbt.onsets,
65-
'color': next(colors)
66-
}
67-
68-
for name, bt in beat_trackers.items():
69-
print('{0}: {1}'.format(name, bt))
54+
colors = [
55+
(255, 0, 0), # red
56+
(0, 255, 0), # green
57+
(0, 165, 255), # blue
58+
(255, 0, 255), # magenta
59+
(255, 255, 0), # yellow
60+
(255, 69, 0), # orange
61+
(0, 255, 255), # cyan
62+
(145, 112, 235), # blue-violet
63+
]
64+
colorcycle = itertools.cycle(colors)
7065

7166
fps = 30
7267

@@ -76,34 +71,93 @@ def main():
7671
frame_duration = 1 / fps
7772
frame_duration_ms = frame_duration * 1000
7873

79-
total_duration = float(audio.shape[0])/44100.0
80-
total_frames = total_duration/frame_duration
74+
total_duration = numpy.floor(float(audio.shape[0]) / 44100.0)
75+
total_frames = int(numpy.ceil(total_duration / frame_duration))
76+
77+
times_vector = numpy.arange(0, total_duration, frame_duration)
78+
79+
all_beat_times = individual_tracker_beat_locations + [
80+
strong_beat_locations,
81+
hbt.beat_consensus,
82+
]
83+
84+
all_beat_frames = [
85+
numpy.concatenate(
86+
(
87+
numpy.zeros(
88+
1,
89+
),
90+
find_closest(times_vector, beat_times),
91+
numpy.ones(
92+
1,
93+
)
94+
* (total_frames - 1),
95+
)
96+
).astype(numpy.int)
97+
for beat_times in all_beat_times
98+
]
99+
100+
off_beat_frames = [
101+
((x[1:] + x[:-1]) / 2).astype(numpy.int) for x in all_beat_frames
102+
]
103+
104+
all_positions = [] # []
105+
for i in range(len(all_beat_frames)):
106+
x = (
107+
numpy.empty(
108+
total_frames,
109+
)
110+
* numpy.nan
111+
)
81112

82-
total_duration = frame_duration * total_frames
113+
x[all_beat_frames[i]] = 1
114+
x[off_beat_frames[i]] = -1
115+
a = pd.Series(x)
116+
all_positions.append(a.interpolate().to_numpy())
83117

84118
blank_frame = numpy.zeros((video_height, video_width, 3), numpy.uint8)
85119

86-
def render_animations(*args, **kwargs):
87-
video_frame = blank_frame.copy()
120+
box_width = int(video_width / 4)
121+
box_edges_horiz = numpy.arange(0, video_width + 1, box_width)
122+
box_centers_horiz = box_edges_horiz[:-1] + int(box_width / 2)
88123

89-
# draw stick figures with text
90-
91-
# draw some text, names of algorithms etc.
92-
cv2.putText(
93-
video_frame,
94-
"BEAT",
95-
all_beat_pos,
96-
cv2.FONT_HERSHEY_SIMPLEX,
97-
2.0,
98-
all_beat_color,
99-
3,
100-
cv2.LINE_AA,
101-
)
102-
cv2.line(image, (20,10), (100,10), (255,0,0), 2)
124+
box_height = int(video_height / 2)
125+
box_edges_vert = numpy.arange(0, video_height + 1, box_height)
126+
box_centers_vert = box_edges_vert[:-1] + int(box_height / 2)
127+
128+
positions = list(itertools.product(box_centers_horiz, box_centers_vert))
129+
130+
frame_index = 0
103131

132+
def render_animations(*args, **kwargs):
133+
nonlocal frame_index
134+
video_frame = blank_frame.copy()
104135

105-
# adjust color on frames
106-
video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)
136+
for i, beats in enumerate(all_beat_frames):
137+
center = positions[i]
138+
try:
139+
interpolated_pos = all_positions[i][frame_index]
140+
except IndexError:
141+
interpolated_pos = 0
142+
143+
current_position = (
144+
center[0],
145+
int(center[1] + (box_height / 2 - 100) * interpolated_pos),
146+
)
147+
148+
# draw some text, names of algorithms etc.
149+
cv2.putText(
150+
video_frame,
151+
str(i),
152+
current_position,
153+
cv2.FONT_HERSHEY_SIMPLEX,
154+
2.0,
155+
colors[i],
156+
3,
157+
cv2.LINE_AA,
158+
)
159+
160+
frame_index += 1
107161
return video_frame
108162

109163
print("Processing video - rendering animations")

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ moviepy==1.0.3
1111
essentia==2.1b6.dev374
1212
madmom==0.16.1
1313
scipy==1.6.0
14+
pandas

0 commit comments

Comments
 (0)