
Commit 88644c4
Open source URMP dataset pipeline

Authored by jesseengel (Magenta Team), committed by Magenta Team.
PiperOrigin-RevId: 460754854
Parent: 915b35c

File tree: 5 files changed (+408, -1 lines)

@@ -0,0 +1,81 @@
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Prepare URMP dataset DDSP and NoteSequence features.

Usage:
====================
ddsp_prepare_urmp_dataset \
--input_filepath='/path/to/input.tfrecord-*' \
--output_filepath='/path/to/output.tfrecord' \
--instrument_key=vn \
--num_shards=10 \
--alsologtostderr

"""

from absl import app
from absl import flags

from ddsp.training.data_preparation.prepare_urmp_dataset_lib import prepare_urmp
import tensorflow.compat.v2 as tf

FLAGS = flags.FLAGS

flags.DEFINE_string('input_filepath', '', 'Input filepath for dataset.')
flags.DEFINE_string('output_filepath', '', 'Output filepath for dataset.')
flags.DEFINE_multi_string(
    'instrument_key', [], 'Instrument keys to extract. '
    'If not set, extract all instruments. Possible keys '
    'are vn, va, vc, db, fl, ob, cl, sax, bn, tpt, hn, '
    'tbn, tba.')
flags.DEFINE_integer(
    'num_shards', None, 'Num shards for output dataset. If '
    'None, this number will be determined automatically.')
flags.DEFINE_bool('batch', True, 'Whether or not to batch the dataset.')
flags.DEFINE_bool('force_monophonic', True, 'Fix URMP note labels such that '
                  'note onsets and offsets do not overlap.')
flags.DEFINE_list(
    'pipeline_options', '--runner=DirectRunner',
    'A comma-separated list of command line arguments to be used as options '
    'for the Beam Pipeline.')
flags.DEFINE_integer('ddsp_sample_rate', 250, 'Sample rate for dataset output.')
flags.DEFINE_integer('audio_sample_rate', 16000, 'Sample rate for URMP audio.')


def run():
  prepare_urmp(
      input_filepath=FLAGS.input_filepath,
      output_filepath=FLAGS.output_filepath,
      instrument_keys=FLAGS.instrument_key,
      num_shards=FLAGS.num_shards,
      batch=FLAGS.batch,
      force_monophonic=FLAGS.force_monophonic,
      pipeline_options=FLAGS.pipeline_options,
      ddsp_sample_rate=FLAGS.ddsp_sample_rate,
      audio_sample_rate=FLAGS.audio_sample_rate)


def main(unused_argv):
  """From command line."""
  run()


def console_entry_point():
  """From pip installed script."""
  app.run(main)


if __name__ == '__main__':
  console_entry_point()
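
For library use, the same preparation can also be invoked directly from Python rather than through the console script above. The sketch below mirrors the call that run() makes, using the same example values as the usage string; the file paths are placeholders for illustration, not values from this commit.

# Minimal sketch: calling prepare_urmp directly (paths are placeholders).
from ddsp.training.data_preparation.prepare_urmp_dataset_lib import prepare_urmp

prepare_urmp(
    input_filepath='/path/to/input.tfrecord-*',
    output_filepath='/path/to/output.tfrecord',
    instrument_keys=['vn'],  # Violin only; pass [] to keep all instruments.
    num_shards=10,
    batch=True,
    force_monophonic=True,
    pipeline_options=['--runner=DirectRunner'],
    ddsp_sample_rate=250,
    audio_sample_rate=16000)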
@@ -0,0 +1,265 @@
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""URMP data import pipeline."""
import apache_beam as beam
import ddsp
from ddsp.training import heuristics
from mir_eval import melody
from note_seq import audio_io
from note_seq import constants
from note_seq import sequences_lib
from note_seq.protobuf import music_pb2
import numpy as np
import tensorflow as tf


DDSP_SAMPLE_RATE = 250
AUDIO_SAMPLE_RATE = 16000


def parse_example(tfexample):
  """Parse tf.Example protos to dict of numpy arrays."""
  features = {
      'id':
          tf.io.FixedLenFeature([], dtype=tf.string),
      'audio':
          tf.io.FixedLenFeature([], dtype=tf.string),
      'f0_hz':
          tf.io.FixedLenSequenceFeature([],
                                        dtype=tf.float32,
                                        allow_missing=True),
      'f0_time':
          tf.io.FixedLenSequenceFeature([],
                                        dtype=tf.float32,
                                        allow_missing=True),
      'sequence':
          tf.io.FixedLenFeature([], dtype=tf.string)
  }
  ex = {
      key: val.numpy()
      for key, val in tf.io.parse_single_example(tfexample, features).items()
  }
  return ex


def get_active_frame_indices(piano_roll):
  """Create matrix of frame indices for active notes relative to onset."""
  active_frame_indices = np.zeros_like(piano_roll.active_velocities)
  for frame_i in range(1, active_frame_indices.shape[0]):
    prev_indices = active_frame_indices[frame_i - 1, :]
    active_notes = piano_roll.active[frame_i, :]
    active_frame_indices[frame_i, :] = (prev_indices + 1) * active_notes
  return active_frame_indices


def attach_metadata(ex, ddsp_sample_rate, audio_sample_rate, force_monophonic):
  """Parse and attach metadata from the dataset."""

  def extract_recording_id(id_string):
    id_string = id_string.split(b'/')[-1]
    id_string = id_string.split(b'.')[0]
    return id_string

  def extract_instrument_id(id_string):
    id_string = extract_recording_id(id_string).split(b'_')
    return id_string[2]

  def extract_notes(sequence_str, expected_seconds):
    ns = music_pb2.NoteSequence.FromString(sequence_str)
    # Total time in the dataset doesn't include silence at the end.
    if force_monophonic:
      for i in range(1, len(ns.notes)):
        note = ns.notes[i]
        prev_note = ns.notes[i - 1]
        onset_frame = int(note.start_time * ddsp_sample_rate)
        prev_note_offset_frame = int(prev_note.end_time * ddsp_sample_rate)
        if prev_note_offset_frame >= onset_frame:
          frames_to_move = (prev_note_offset_frame - onset_frame) + 1
          # Move the previous note's offset (end_time) back by frames_to_move
          # frames, expressed in seconds.
          prev_note.end_time -= float(frames_to_move) / ddsp_sample_rate

    ns.total_time = expected_seconds
    piano_roll = sequences_lib.sequence_to_pianoroll(
        ns,
        frames_per_second=ddsp_sample_rate,
        min_pitch=constants.MIN_MIDI_PITCH,
        max_pitch=constants.MAX_MIDI_PITCH,
        onset_mode='length_ms')

    note_dict = {
        'note_active_velocities': piano_roll.active_velocities,
        'note_active_frame_indices': get_active_frame_indices(piano_roll),
        'note_onsets': piano_roll.onsets,
        'note_offsets': piano_roll.offsets
    }

    return note_dict

  ex['recording_id'] = extract_recording_id(ex['id'])
  ex['instrument_id'] = extract_instrument_id(ex['id'])
  ex['audio'] = audio_io.wav_data_to_samples_librosa(
      ex['audio'], sample_rate=audio_sample_rate)
  expected_seconds = ex['audio'].shape[0] / audio_sample_rate
  ex.update(extract_notes(ex['sequence'], expected_seconds))
  beam.metrics.Metrics.distribution('prepare-urmp',
                                    'orig-audio-len').update(len(ex['audio']))
  return ex


def normalize_audio(ex, max_audio):
  ex['audio'] /= max_audio
  return ex


def resample(ex, ddsp_sample_rate, audio_sample_rate):
  """Resample features to standard DDSP sample rate."""
  f0_times = ex['f0_time']
  f0_orig = ex['f0_hz']
  max_time = np.max(f0_times)
  new_times = np.linspace(0, max_time, int(ddsp_sample_rate * max_time))
  if f0_times[0] > 0:
    f0_orig = np.insert(f0_orig, 0, f0_orig[0])
    f0_times = np.insert(f0_times, 0, 0)
  f0_interpolated, _ = melody.resample_melody_series(
      f0_times, f0_orig,
      melody.freq_to_voicing(f0_orig)[1], new_times)
  ex['f0_hz'] = f0_interpolated
  ex['f0_time'] = new_times
  ex['orig_f0_hz'] = f0_orig
  ex['orig_f0_time'] = f0_times

  # Truncate audio to an integer multiple of f0_hz vector.
  num_audio_samples = round(
      len(ex['f0_hz']) * (audio_sample_rate / ddsp_sample_rate))
  beam.metrics.Metrics.distribution(
      'prepare-urmp',
      'resampled-audio-diff').update(num_audio_samples - len(ex['audio']))

  ex['audio'] = ex['audio'][:num_audio_samples]

  # Truncate pianoroll features to length of f0_hz vector.
  for key in [
      'note_active_frame_indices', 'note_active_velocities', 'note_onsets',
      'note_offsets'
  ]:
    ex[key] = ex[key][:len(ex['f0_hz']), :]

  return ex


def batch_dataset(ex, audio_sample_rate, ddsp_sample_rate):
  """Split features and audio into 4 second sliding windows."""
  batched = []
  for key, vec in ex.items():
    if isinstance(vec, np.ndarray):
      if key == 'audio':
        sampling_rate = audio_sample_rate
      else:
        sampling_rate = ddsp_sample_rate

      frames = heuristics.window_array(vec, sampling_rate, 4.0, 0.25)
      if not batched:
        batched = [{} for _ in range(len(frames))]
      for i, frame in enumerate(frames):
        batched[i][key] = frame

  # Once batches are created, replicate ids and metadata over all elements.
  for key, val in ex.items():
    if not isinstance(val, np.ndarray):
      for batch in batched:
        batch[key] = val

  beam.metrics.Metrics.counter('prepare-urmp',
                               'batches-created').inc(len(batched))
  return batched


def attach_ddsp_features(ex):
  ex['loudness_db'] = ddsp.spectral_ops.compute_loudness(ex['audio'])
  ex['power_db'] = ddsp.spectral_ops.compute_power(ex['audio'], frame_size=256)
  # Ground truth annotations are set with confidence 1.0.
  ex['f0_confidence'] = np.ones_like(ex['f0_hz'])
  beam.metrics.Metrics.counter('prepare-urmp', 'ddsp-features-attached').inc()
  return ex


def serialize_tfexample(ex):
  """Creates a tf.Example message ready to be written to a file."""

  def _feature(arr):
    """Returns a feature from a numpy array or string."""
    if isinstance(arr, (bytes, str)):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr]))
    else:
      arr = np.asarray(arr).reshape(-1)
      return tf.train.Feature(float_list=tf.train.FloatList(value=arr))

  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  feature = {k: _feature(v) for k, v in ex.items()}

  # Create a Features message using tf.train.Example.
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto


def prepare_urmp(input_filepath,
                 output_filepath,
                 instrument_keys,
                 num_shards,
                 batch,
                 force_monophonic,
                 pipeline_options,
                 ddsp_sample_rate=DDSP_SAMPLE_RATE,
                 audio_sample_rate=AUDIO_SAMPLE_RATE):
  """Pipeline for parsing URMP dataset to a usable format for DDSP."""
  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      pipeline_options)
  with beam.Pipeline(options=pipeline_options) as pipeline:
    examples = (
        pipeline
        |
        'read_tfrecords' >> beam.io.tfrecordio.ReadFromTFRecord(input_filepath)
        | 'parse_example' >> beam.Map(parse_example)
        | 'attach_metadata' >> beam.Map(
            attach_metadata,
            ddsp_sample_rate=ddsp_sample_rate,
            audio_sample_rate=audio_sample_rate,
            force_monophonic=force_monophonic))

    if instrument_keys:
      examples |= 'filter_instruments' >> beam.Filter(
          lambda ex: ex['instrument_id'].decode() in instrument_keys)

    examples |= 'resample' >> beam.Map(
        resample,
        ddsp_sample_rate=ddsp_sample_rate,
        audio_sample_rate=audio_sample_rate)
    if batch:
      examples |= 'batch' >> beam.FlatMap(
          batch_dataset,
          audio_sample_rate=audio_sample_rate,
          ddsp_sample_rate=ddsp_sample_rate)
    _ = (
        examples
        | 'attach_ddsp_features' >> beam.Map(attach_ddsp_features)
        | 'filter_silence' >>
        beam.Filter(lambda ex: np.any(ex['loudness_db'] > -70))
        | 'serialize_tfexamples' >> beam.Map(serialize_tfexample)
        | 'shuffle' >> beam.Reshuffle()
        | beam.io.tfrecordio.WriteToTFRecord(
            output_filepath,
            num_shards=num_shards,
            coder=beam.coders.ProtoCoder(tf.train.Example)))
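
After the pipeline runs, the output is a sharded TFRecord of tf.train.Example protos: serialize_tfexample stores string fields (e.g. 'id', 'recording_id', 'instrument_id', 'sequence') as bytes features and flattens every numpy array (e.g. 'audio', 'f0_hz', 'loudness_db', the note_* pianorolls) into float lists. Below is a minimal sketch of reading a few of those keys back with tf.data; the output path is a placeholder and only a subset of the serialized keys is parsed.

# Minimal sketch: reading the prepared dataset back (path is a placeholder).
import tensorflow as tf

features = {
    'instrument_id':
        tf.io.FixedLenFeature([], dtype=tf.string),
    'audio':
        tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
    'f0_hz':
        tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
    'loudness_db':
        tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
}

dataset = (
    tf.data.TFRecordDataset(tf.io.gfile.glob('/path/to/output.tfrecord*'))
    .map(lambda raw: tf.io.parse_single_example(raw, features)))

for ex in dataset.take(1):
  print(ex['instrument_id'].numpy(), ex['audio'].shape, ex['f0_hz'].shape)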
