# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""URMP data import pipeline.""" |
| 16 | +import apache_beam as beam |
| 17 | +import ddsp |
| 18 | +from ddsp.training import heuristics |
| 19 | +from mir_eval import melody |
| 20 | +from note_seq import audio_io |
| 21 | +from note_seq import constants |
| 22 | +from note_seq import sequences_lib |
| 23 | +from note_seq.protobuf import music_pb2 |
| 24 | +import numpy as np |
| 25 | +import tensorflow as tf |
| 26 | + |
| 27 | + |
# Frame rate (frames per second) of DDSP features such as f0 and the note
# piano rolls.
DDSP_SAMPLE_RATE = 250
# Sample rate (Hz) of the audio waveform.
AUDIO_SAMPLE_RATE = 16000


def parse_example(tfexample):
  """Parse tf.Example protos to dict of numpy arrays."""
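  # Note (added for clarity): the .numpy() call below assumes TF2's default
  # eager execution, so parsing runs as plain Python inside each Beam worker
  # rather than inside a TF graph.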
  features = {
      'id': tf.io.FixedLenFeature([], dtype=tf.string),
      'audio': tf.io.FixedLenFeature([], dtype=tf.string),
      'f0_hz': tf.io.FixedLenSequenceFeature(
          [], dtype=tf.float32, allow_missing=True),
      'f0_time': tf.io.FixedLenSequenceFeature(
          [], dtype=tf.float32, allow_missing=True),
      'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
  }
  ex = {
      key: val.numpy()
      for key, val in tf.io.parse_single_example(tfexample, features).items()
  }
  return ex


def get_active_frame_indices(piano_roll):
  """Create matrix of frame indices for active notes relative to onset."""
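  # Illustrative example (not from the original source): a pitch active in
  # frames 3..6 produces [..., 0, 1, 2, 3, 4, 0, ...] in that pitch's column,
  # i.e. a count of frames since the note onset, and 0 where it is inactive.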
  active_frame_indices = np.zeros_like(piano_roll.active_velocities)
  for frame_i in range(1, active_frame_indices.shape[0]):
    prev_indices = active_frame_indices[frame_i - 1, :]
    active_notes = piano_roll.active[frame_i, :]
    active_frame_indices[frame_i, :] = (prev_indices + 1) * active_notes
  return active_frame_indices


def attach_metadata(ex, ddsp_sample_rate, audio_sample_rate, force_monophonic):
  """Parse and attach metadata from the dataset."""

  def extract_recording_id(id_string):
    id_string = id_string.split(b'/')[-1]
    id_string = id_string.split(b'.')[0]
    return id_string

  def extract_instrument_id(id_string):
    id_string = extract_recording_id(id_string).split(b'_')
    return id_string[2]

  def extract_notes(sequence_str, expected_seconds):
    ns = music_pb2.NoteSequence.FromString(sequence_str)
    if force_monophonic:
      for i in range(1, len(ns.notes)):
        note = ns.notes[i]
        prev_note = ns.notes[i - 1]
        onset_frame = int(note.start_time * ddsp_sample_rate)
        prev_note_offset_frame = int(prev_note.end_time * ddsp_sample_rate)
        if prev_note_offset_frame >= onset_frame:
          frames_to_move = (prev_note_offset_frame - onset_frame) + 1
          # Move the previous note's offset (end time) back by frames_to_move
          # frames, expressed in seconds, so consecutive notes never overlap.
          prev_note.end_time -= float(frames_to_move) / ddsp_sample_rate

    # Total time in the dataset doesn't include silence at the end, so set it
    # from the audio length instead.
    ns.total_time = expected_seconds
    piano_roll = sequences_lib.sequence_to_pianoroll(
        ns,
        frames_per_second=ddsp_sample_rate,
        min_pitch=constants.MIN_MIDI_PITCH,
        max_pitch=constants.MAX_MIDI_PITCH,
        onset_mode='length_ms')

    note_dict = {
        'note_active_velocities': piano_roll.active_velocities,
        'note_active_frame_indices': get_active_frame_indices(piano_roll),
        'note_onsets': piano_roll.onsets,
        'note_offsets': piano_roll.offsets
    }

    return note_dict

  ex['recording_id'] = extract_recording_id(ex['id'])
  ex['instrument_id'] = extract_instrument_id(ex['id'])
  ex['audio'] = audio_io.wav_data_to_samples_librosa(
      ex['audio'], sample_rate=audio_sample_rate)
  expected_seconds = ex['audio'].shape[0] / audio_sample_rate
  ex.update(extract_notes(ex['sequence'], expected_seconds))
  beam.metrics.Metrics.distribution('prepare-urmp',
                                    'orig-audio-len').update(len(ex['audio']))
  return ex


def normalize_audio(ex, max_audio):
  """Normalize the audio by a precomputed maximum value."""
  ex['audio'] /= max_audio
  return ex


def resample(ex, ddsp_sample_rate, audio_sample_rate):
  """Resample features to the standard DDSP sample rate."""
  f0_times = ex['f0_time']
  f0_orig = ex['f0_hz']
  max_time = np.max(f0_times)
  new_times = np.linspace(0, max_time, int(ddsp_sample_rate * max_time))
  if f0_times[0] > 0:
    f0_orig = np.insert(f0_orig, 0, f0_orig[0])
    f0_times = np.insert(f0_times, 0, 0)
  f0_interpolated, _ = melody.resample_melody_series(
      f0_times, f0_orig,
      melody.freq_to_voicing(f0_orig)[1], new_times)
  ex['f0_hz'] = f0_interpolated
  ex['f0_time'] = new_times
  ex['orig_f0_hz'] = f0_orig
  ex['orig_f0_time'] = f0_times

  # Truncate audio so its length is an integer multiple of the number of audio
  # samples per f0 frame (audio_sample_rate / ddsp_sample_rate, i.e.
  # 16000 / 250 = 64 at the default rates).
  num_audio_samples = round(
      len(ex['f0_hz']) * (audio_sample_rate / ddsp_sample_rate))
  beam.metrics.Metrics.distribution(
      'prepare-urmp',
      'resampled-audio-diff').update(num_audio_samples - len(ex['audio']))

  ex['audio'] = ex['audio'][:num_audio_samples]

  # Truncate pianoroll features to the length of the f0_hz vector.
  for key in [
      'note_active_frame_indices', 'note_active_velocities', 'note_onsets',
      'note_offsets'
  ]:
    ex[key] = ex[key][:len(ex['f0_hz']), :]

  return ex


def batch_dataset(ex, audio_sample_rate, ddsp_sample_rate):
  """Split features and audio into 4-second sliding windows."""
  batched = []
  for key, vec in ex.items():
    if isinstance(vec, np.ndarray):
      if key == 'audio':
        sampling_rate = audio_sample_rate
      else:
        sampling_rate = ddsp_sample_rate

      frames = heuristics.window_array(vec, sampling_rate, 4.0, 0.25)
      if not batched:
        batched = [{} for _ in range(len(frames))]
      for i, frame in enumerate(frames):
        batched[i][key] = frame

  # Once batches are created, replicate ids and metadata over all elements.
  for key, val in ex.items():
    if not isinstance(val, np.ndarray):
      for batch in batched:
        batch[key] = val

  beam.metrics.Metrics.counter('prepare-urmp',
                               'batches-created').inc(len(batched))
  return batched


def attach_ddsp_features(ex):
  """Attach loudness, power, and f0 confidence features."""
  ex['loudness_db'] = ddsp.spectral_ops.compute_loudness(ex['audio'])
  ex['power_db'] = ddsp.spectral_ops.compute_power(ex['audio'], frame_size=256)
  # Ground-truth f0 annotations are assigned a confidence of 1.0.
  ex['f0_confidence'] = np.ones_like(ex['f0_hz'])
  beam.metrics.Metrics.counter('prepare-urmp', 'ddsp-features-attached').inc()
  return ex


def serialize_tfexample(ex):
  """Creates a tf.Example message ready to be written to a file."""

  def _feature(arr):
    """Returns a feature from a numpy array, bytes, or string."""
    if isinstance(arr, (bytes, str)):
      # BytesList requires bytes; encode str values before wrapping.
      if isinstance(arr, str):
        arr = arr.encode('utf-8')
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr]))
    else:
      arr = np.asarray(arr).reshape(-1)
      return tf.train.Feature(float_list=tf.train.FloatList(value=arr))

  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  feature = {k: _feature(v) for k, v in ex.items()}

  # Create a Features message using tf.train.Example.
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto


def prepare_urmp(input_filepath,
                 output_filepath,
                 instrument_keys,
                 num_shards,
                 batch,
                 force_monophonic,
                 pipeline_options,
                 ddsp_sample_rate=DDSP_SAMPLE_RATE,
                 audio_sample_rate=AUDIO_SAMPLE_RATE):
  """Pipeline for parsing the URMP dataset into a usable format for DDSP."""
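  # Stage order (summarizing the pipeline constructed below): read TFRecords
  # -> parse examples -> attach note and id metadata -> optionally filter by
  # instrument -> resample f0 and truncate audio -> optionally window into
  # 4-second batches -> attach loudness / power / f0 confidence -> drop silent
  # clips -> serialize, shuffle, and write sharded TFRecords.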
  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      pipeline_options)
  with beam.Pipeline(options=pipeline_options) as pipeline:
    examples = (
        pipeline
        |
        'read_tfrecords' >> beam.io.tfrecordio.ReadFromTFRecord(input_filepath)
        | 'parse_example' >> beam.Map(parse_example)
        | 'attach_metadata' >> beam.Map(
            attach_metadata,
            ddsp_sample_rate=ddsp_sample_rate,
            audio_sample_rate=audio_sample_rate,
            force_monophonic=force_monophonic))

    if instrument_keys:
      examples |= 'filter_instruments' >> beam.Filter(
          lambda ex: ex['instrument_id'].decode() in instrument_keys)

    examples |= 'resample' >> beam.Map(
        resample,
        ddsp_sample_rate=ddsp_sample_rate,
        audio_sample_rate=audio_sample_rate)
    if batch:
      examples |= 'batch' >> beam.FlatMap(
          batch_dataset,
          audio_sample_rate=audio_sample_rate,
          ddsp_sample_rate=ddsp_sample_rate)
    _ = (
        examples
        | 'attach_ddsp_features' >> beam.Map(attach_ddsp_features)
        | 'filter_silence' >>
        beam.Filter(lambda ex: np.any(ex['loudness_db'] > -70))
        | 'serialize_tfexamples' >> beam.Map(serialize_tfexample)
        | 'shuffle' >> beam.Reshuffle()
        | beam.io.tfrecordio.WriteToTFRecord(
            output_filepath,
            num_shards=num_shards,
            coder=beam.coders.ProtoCoder(tf.train.Example)))
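

# Example invocation (an illustrative sketch only; the file paths, instrument
# key, shard count, and runner flag below are placeholders rather than values
# taken from this file):
#
#   prepare_urmp(
#       input_filepath='/path/to/urmp_raw.tfrecord*',
#       output_filepath='/path/to/urmp_ddsp.tfrecord',
#       instrument_keys=['vn'],
#       num_shards=10,
#       batch=True,
#       force_monophonic=True,
#       pipeline_options=['--runner=DirectRunner'])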