forked from tensorflow/datasets
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
15 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
# coding=utf-8 | ||
# Copyright 2019 The TensorFlow Datasets Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Groove Midi Dataset (GMD).""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import collections | ||
import copy | ||
import csv | ||
import io | ||
import os | ||
|
||
from absl import logging | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_datasets.public_api as tfds | ||
|
||
_DESCRIPTION = """\ | ||
The Groove MIDI Dataset (GMD) is composed of 13.6 hours of aligned MIDI and | ||
(synthesized) audio of human-performed, tempo-aligned expressive drumming | ||
captured on a Roland TD-11 V-Drum electronic drum kit. | ||
""" | ||
|
||
_CITATION = """ | ||
@inproceedings{groove2019, | ||
Author = {Jon Gillick and Adam Roberts and Jesse Engel and Douglas Eck and David Bamman}, | ||
Title = {Learning to Groove with Inverse Sequence Transformations}, | ||
Booktitle = {International Conference on Machine Learning (ICML)} | ||
Year = {2019}, | ||
} | ||
""" | ||
|
||
_PRIMARY_STYLES = [ | ||
"afrobeat", "afrocuban", "blues", "country", "dance", "funk", "gospel", | ||
"highlife", "hiphop", "jazz", "latin", "middleeastern", "neworleans", "pop", | ||
"punk", "reggae", "rock", "soul"] | ||
|
||
_TIME_SIGNATURES = ["3-4", "4-4", "5-4", "5-8", "6-8"] | ||
|
||
_DOWNLOAD_URL = "https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0.zip" | ||
_DOWNLOAD_URL_MIDI_ONLY = "https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip" | ||
|
||
|
||
class GrooveConfig(tfds.core.BuilderConfig):
  """BuilderConfig for Groove Dataset."""

  def __init__(self, split_bars=None, include_audio=True, audio_rate=16000,
               **kwargs):
    """Constructs a GrooveConfig.

    The config name is derived from the arguments as "<bars>-<audio>", where
    <bars> is "<split_bars>bar" or "full" and <audio> is "<audio_rate>hz" or
    "midionly".

    Args:
      split_bars: int, number of bars to include per example using a sliding
        window across the raw data, or will not split if None.
      include_audio: bool, whether to include audio in the examples. If True,
        examples with missing audio will be excluded.
      audio_rate: int, sample rate to use for audio.
      **kwargs: keyword arguments forwarded to super.
    """
    bars_part = ("%dbar" % split_bars) if split_bars else "full"
    audio_part = ("%dhz" % audio_rate) if include_audio else "midionly"
    config_name = "-".join([bars_part, audio_part])

    super(GrooveConfig, self).__init__(name=config_name, **kwargs)
    self.split_bars = split_bars
    self.include_audio = include_audio
    self.audio_rate = audio_rate
|
||
|
||
class Groove(tfds.core.GeneratorBasedBuilder):
  """The Groove MIDI Dataset (GMD) of drum performances.

  Each builder config chooses whether audio is included and whether each
  performance is emitted whole or as a sliding window of `split_bars` bars.
  """

  BUILDER_CONFIGS = [
      GrooveConfig(
          include_audio=False,
          version="1.0.0",
          description="Groove dataset without audio, unsplit."
      ),
      GrooveConfig(
          include_audio=True,
          version="1.0.0",
          description="Groove dataset with audio, unsplit."
      ),
      GrooveConfig(
          include_audio=False,
          split_bars=2,
          version="1.0.0",
          description="Groove dataset without audio, split into 2-bar chunks."
      ),
      GrooveConfig(
          include_audio=True,
          split_bars=2,
          version="1.0.0",
          description="Groove dataset with audio, split into 2-bar chunks."
      ),
      GrooveConfig(
          include_audio=False,
          split_bars=4,
          version="1.0.0",
          description="Groove dataset without audio, split into 4-bar chunks."
      ),
  ]

  def _info(self):
    """Returns the `tfds.core.DatasetInfo` describing the example features.

    The "audio" feature is only declared when the active builder config has
    `include_audio` set.
    """
    features_dict = {
        "id": tf.string,
        "drummer":
            tfds.features.ClassLabel(
                names=["drummer%d" % i for i in range(1, 11)]),
        "type": tfds.features.ClassLabel(names=["beat", "fill"]),
        "bpm": tf.int32,
        "time_signature": tfds.features.ClassLabel(names=_TIME_SIGNATURES),
        "style": {
            "primary": tfds.features.ClassLabel(names=_PRIMARY_STYLES),
            "secondary": tf.string,
        },
        "midi": tf.string
    }
    if self.builder_config.include_audio:
      features_dict["audio"] = tfds.features.Tensor(
          shape=[None], dtype=tf.float32)
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict(features_dict),
        urls=["https://g.co/magenta/groove-dataset"],
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Downloads the archive and returns one generator per split.

    Args:
      dl_manager: download manager used to fetch and extract the archive.

    Returns:
      A list of `tfds.core.SplitGenerator`, one for each distinct value found
      in the "split" column of `info.csv`.
    """
    # Download data. The MIDI-only archive is used when audio is not needed.
    data_dir = os.path.join(
        dl_manager.download_and_extract(
            _DOWNLOAD_URL if self.builder_config.include_audio else
            _DOWNLOAD_URL_MIDI_ONLY),
        "groove")

    # Group the metadata rows by the split they are assigned to.
    rows = collections.defaultdict(list)
    with tf.io.gfile.GFile(os.path.join(data_dir, "info.csv")) as f:
      reader = csv.DictReader(f)
      for row in reader:
        rows[row["split"]].append(row)

    return [
        tfds.core.SplitGenerator(  # pylint: disable=g-complex-comprehension
            name=split,
            num_shards=10 if split == "train" else 1,
            gen_kwargs={"rows": split_rows, "data_dir": data_dir})
        for split, split_rows in rows.items()]

  def _generate_examples(self, rows, data_dir):
    """Yields examples for one split.

    Args:
      rows: list of dicts, the `info.csv` rows belonging to the split.
      data_dir: str, directory containing the extracted dataset files.

    Yields:
      Feature dicts. When the config has no `split_bars`, one dict is yielded
      per row with the full MIDI (and audio). Otherwise one dict is yielded per
      `split_bars`-bar sliding window, with the id "<row id>:<window index>"
      (window index zero-padded to three digits). Rows without audio are
      skipped when the config includes audio.
    """
    split_bars = self.builder_config.split_bars
    include_audio = self.builder_config.include_audio
    audio_rate = self.builder_config.audio_rate

    for row in rows:
      split_genre = row["style"].split("/")
      with tf.io.gfile.GFile(
          os.path.join(data_dir, row["midi_filename"]), "rb") as midi_f:
        midi = midi_f.read()
      audio = None
      if include_audio:
        if not row["audio_filename"]:
          # Skip examples with no audio.
          logging.warning("Skipping example with no audio: %s", row["id"])
          continue
        wav_path = os.path.join(data_dir, row["audio_filename"])
        audio = _load_wav(wav_path, audio_rate)

      bpm = int(row["bpm"])
      example = {
          "id": row["id"],
          "drummer": row["drummer"],
          "type": row["beat_type"],
          "bpm": bpm,
          "time_signature": row["time_signature"],
          "style": {
              "primary": split_genre[0],
              "secondary": split_genre[1] if len(split_genre) == 2 else ""
          },
      }
      if not split_bars:
        # Yield full example.
        example["midi"] = midi
        if audio is not None:
          example["audio"] = audio
        yield example
        continue

      # Yield split examples.
      beats_per_bar = int(row["time_signature"].split("-")[0])
      bar_duration = 60 / bpm * beats_per_bar

      pm = tfds.core.lazy_imports.pretty_midi.PrettyMIDI(io.BytesIO(midi))
      total_duration = pm.get_end_time()

      # Pad final bar if at least half filled: the bar count is rounded to the
      # nearest integer and the audio is zero-padded up to that many bars.
      total_bars = int(round(total_duration / bar_duration))
      total_frames = int(total_bars * bar_duration * audio_rate)
      if audio is not None and len(audio) < total_frames:
        audio = np.pad(audio, [0, total_frames - len(audio)], "constant")

      for i in range(total_bars - split_bars + 1):
        time_range = [i * bar_duration, (i + split_bars) * bar_duration]

        # Start from a fresh copy for every window. Mutating the shared dict
        # would accumulate id suffixes ("x:000:001:...") and make every
        # yielded example alias the same object.
        split_example = dict(example)
        split_example["id"] = "%s:%03d" % (example["id"], i)

        # Split MIDI.
        pm_split = copy.deepcopy(pm)
        pm_split.adjust_times(time_range, [0, split_bars * bar_duration])
        pm_split.time_signature_changes = pm.time_signature_changes
        midi_split = io.BytesIO()
        pm_split.write(midi_split)
        split_example["midi"] = midi_split.getvalue()

        # Split audio.
        if audio is not None:
          split_example["audio"] = audio[
              int(time_range[0] * audio_rate):
              int(time_range[1] * audio_rate)]

        yield split_example
|
||
|
||
def _load_wav(path, sample_rate):
  """Loads a WAV file as a mono float32 array at the given sample rate.

  Args:
    path: str, path of the WAV file to read.
    sample_rate: int, frame rate the audio is converted to.

  Returns:
    A 1-D `np.float32` array of samples, scaled by `2**(8 * sample_width)`.
  """
  audio_segment_cls = tfds.core.lazy_imports.pydub.AudioSegment
  with tf.io.gfile.GFile(path, "rb") as audio_f:
    audio_segment = audio_segment_cls.from_file(audio_f, format="wav")
    audio_segment = audio_segment.set_channels(1)
    audio_segment = audio_segment.set_frame_rate(sample_rate)
  samples = audio_segment.get_array_of_samples()
  audio = np.array(samples).astype(np.float32)
  # Convert from int to float representation.
  # NOTE(review): for signed PCM this scale presumably maps samples into
  # [-0.5, 0.5) rather than [-1, 1) -- confirm this is intended.
  full_scale = 2**(8 * audio_segment.sample_width)
  audio /= full_scale
  return audio
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# coding=utf-8 | ||
# Copyright 2019 The TensorFlow Datasets Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for Groove dataset module.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from tensorflow_datasets import testing | ||
from tensorflow_datasets.audio import groove | ||
|
||
|
||
class GrooveFullTest(testing.DatasetBuilderTestCase):
  """Tests the unsplit config with audio ("full-16000hz")."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["full-16000hz"]
  # Expected number of examples per split.
  SPLITS = dict(train=2, test=1)
  DL_EXTRACT_RESULT = ".."
|
||
|
||
class GrooveFullMidiOnlyTest(testing.DatasetBuilderTestCase):
  """Tests the unsplit config without audio ("full-midionly")."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["full-midionly"]
  # Expected number of examples per split.
  SPLITS = dict(train=3, test=1)
  DL_EXTRACT_RESULT = ".."
|
||
|
||
class Groove2BarTest(testing.DatasetBuilderTestCase):
  """Tests the 2-bar sliding-window config with audio ("2bar-16000hz")."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["2bar-16000hz"]
  # Expected number of examples per split.
  SPLITS = dict(
      train=5,  # 3, 2
      test=1,
  )
  DL_EXTRACT_RESULT = ".."
|
||
|
||
class Groove2BarMidiOnlyTest(testing.DatasetBuilderTestCase):
  """Tests the 2-bar sliding-window config without audio ("2bar-midionly")."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["2bar-midionly"]
  # Expected number of examples per split.
  SPLITS = dict(
      train=6,  # 3, 2, 1
      test=1,
  )
  DL_EXTRACT_RESULT = ".."
|
||
|
||
# Run the test cases above when this module is executed as a script.
if __name__ == "__main__":
  testing.test_main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+727 Bytes
...atasets/testing/test_data/fake_examples/groove/drummer3/session1/23_jazz_120_beat_4-4.mid
Binary file not shown.
Binary file added
BIN
+693 KB
...atasets/testing/test_data/fake_examples/groove/drummer3/session1/23_jazz_120_beat_4-4.wav
Binary file not shown.
Binary file added
BIN
+271 Bytes
...atasets/testing/test_data/fake_examples/groove/drummer3/session1/44_rock_120_beat_4-4.mid
Binary file not shown.
Binary file added
BIN
+347 KB
...atasets/testing/test_data/fake_examples/groove/drummer3/session1/44_rock_120_beat_4-4.wav
Binary file not shown.
Binary file added
BIN
+467 Bytes
..._datasets/testing/test_data/fake_examples/groove/drummer7/session1/3_rock_86_fill_4-4.mid
Binary file not shown.
Binary file added
BIN
+1.18 MB
..._datasets/testing/test_data/fake_examples/groove/drummer7/session1/3_rock_86_fill_4-4.wav
Binary file not shown.
Binary file added
BIN
+305 Bytes
.../testing/test_data/fake_examples/groove/drummer7/session2/116_jazz-fusion_96_fill_4-4.mid
Binary file not shown.
5 changes: 5 additions & 0 deletions
5
tensorflow_datasets/testing/test_data/fake_examples/groove/info.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split | ||
drummer3,drummer3/session1,drummer3/session1/23,jazz,120,beat,4-4,drummer3/session1/23_jazz_120_beat_4-4.mid,drummer3/session1/23_jazz_120_beat_4-4.wav,7.96875,train | ||
drummer3,drummer3/session1,drummer3/session1/44,rock/hard,120,fill,4-4,drummer3/session1/44_rock_120_beat_4-4.mid,drummer3/session1/44_rock_120_beat_4-4.wav,3.986458,test | ||
drummer7,drummer7/session1,drummer7/session1/3,rock,86,fill,3-4,drummer7/session1/3_rock_86_fill_4-4.mid,drummer7/session1/3_rock_86_fill_4-4.wav,5.696218,train | ||
drummer7,drummer7/session2,drummer7/session2/116,jazz/fusion,96,beat,4-4,drummer7/session2/116_jazz-fusion_96_fill_4-4.mid,,4.41276,train |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip 3260318 651cbc524ffb891be1a3e46d89dc82a1cecb09a57c748c7b45b844c4841dcc1e | ||
https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0.zip 5111599714 21559feb2f1c96ca53988fd4d7060b1f2afe1d854fb2a8dcea5ff95cf3cce7e9 |