Skip to content

Commit

Permalink
Add Groove to TFDS. Take 2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 246058528
  • Loading branch information
adarob authored and copybara-github committed May 1, 2019
1 parent 6393a6b commit dbda817
Show file tree
Hide file tree
Showing 15 changed files with 328 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
'scikit-image',
'scipy'
],
'groove': ['pretty_midi', 'pydub'],
'librispeech': ['pydub'], # and ffmpeg installed
'svhn': ['scipy'],
'wikipedia': ['mwparserfromhell', 'apache_beam'],
Expand Down
1 change: 1 addition & 0 deletions tensorflow_datasets/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Audio datasets."""

from tensorflow_datasets.audio.groove import Groove
from tensorflow_datasets.audio.librispeech import Librispeech
from tensorflow_datasets.audio.librispeech import LibrispeechConfig
from tensorflow_datasets.audio.nsynth import Nsynth
246 changes: 246 additions & 0 deletions tensorflow_datasets/audio/groove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Groove Midi Dataset (GMD)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import csv
import io
import os

from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_datasets.public_api as tfds

# Human-readable summary of the dataset.
_DESCRIPTION = """\
The Groove MIDI Dataset (GMD) is composed of 13.6 hours of aligned MIDI and
(synthesized) audio of human-performed, tempo-aligned expressive drumming
captured on a Roland TD-11 V-Drum electronic drum kit.
"""

# BibTeX entry for the dataset. Every field, including `Booktitle`, must end
# with a comma so that the entry parses.
_CITATION = """
@inproceedings{groove2019,
Author = {Jon Gillick and Adam Roberts and Jesse Engel and Douglas Eck and David Bamman},
Title = {Learning to Groove with Inverse Sequence Transformations},
Booktitle = {International Conference on Machine Learning (ICML)},
Year = {2019},
}
"""

# Class names for the "style/primary" feature.
_PRIMARY_STYLES = [
    "afrobeat", "afrocuban", "blues", "country", "dance", "funk", "gospel",
    "highlife", "hiphop", "jazz", "latin", "middleeastern", "neworleans", "pop",
    "punk", "reggae", "rock", "soul"]

# Class names for the "time_signature" feature, written "<numerator>-<denominator>".
_TIME_SIGNATURES = ["3-4", "4-4", "5-4", "5-8", "6-8"]

# Archive with MIDI and audio, and the MIDI-only archive used when a config
# does not include audio.
_DOWNLOAD_URL = "https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0.zip"
_DOWNLOAD_URL_MIDI_ONLY = "https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip"


class GrooveConfig(tfds.core.BuilderConfig):
  """Builder configuration describing one variant of the Groove dataset."""

  def __init__(self, split_bars=None, include_audio=True, audio_rate=16000,
               **kwargs):
    """Creates a GrooveConfig whose name is derived from its options.

    The name has the form "<bars>-<audio>", where <bars> is "<N>bar" or
    "full" and <audio> is "<rate>hz" or "midionly".

    Args:
      split_bars: int, number of bars to include per example using a sliding
        window across the raw data, or will not split if None.
      include_audio: bool, whether to include audio in the examples. If True,
        examples with missing audio will be excluded.
      audio_rate: int, sample rate to use for audio.
      **kwargs: keyword arguments forwarded to super.
    """
    bars_tag = "%dbar" % split_bars if split_bars else "full"
    audio_tag = "%dhz" % audio_rate if include_audio else "midionly"
    config_name = "-".join([bars_tag, audio_tag])

    super(GrooveConfig, self).__init__(name=config_name, **kwargs)
    self.split_bars = split_bars
    self.include_audio = include_audio
    self.audio_rate = audio_rate


class Groove(tfds.core.GeneratorBasedBuilder):
  """The Groove MIDI Dataset (GMD) of drum performances.

  Each builder config selects whether audio is included and whether
  performances are yielded whole or as fixed-length windows of bars.
  """

  BUILDER_CONFIGS = [
      GrooveConfig(
          include_audio=False,
          version="1.0.0",
          description="Groove dataset without audio, unsplit."
      ),
      GrooveConfig(
          include_audio=True,
          version="1.0.0",
          description="Groove dataset with audio, unsplit."
      ),
      GrooveConfig(
          include_audio=False,
          split_bars=2,
          version="1.0.0",
          description="Groove dataset without audio, split into 2-bar chunks."
      ),
      GrooveConfig(
          include_audio=True,
          split_bars=2,
          version="1.0.0",
          description="Groove dataset with audio, split into 2-bar chunks."
      ),
      GrooveConfig(
          include_audio=False,
          split_bars=4,
          version="1.0.0",
          description="Groove dataset without audio, split into 4-bar chunks."
      ),
  ]

  def _info(self):
    """Returns the `DatasetInfo`; an "audio" feature is added if configured."""
    features_dict = {
        "id": tf.string,
        "drummer":
            tfds.features.ClassLabel(
                names=["drummer%d" % i for i in range(1, 11)]),
        "type": tfds.features.ClassLabel(names=["beat", "fill"]),
        "bpm": tf.int32,
        "time_signature": tfds.features.ClassLabel(names=_TIME_SIGNATURES),
        "style": {
            "primary": tfds.features.ClassLabel(names=_PRIMARY_STYLES),
            "secondary": tf.string,
        },
        "midi": tf.string
    }
    if self.builder_config.include_audio:
      features_dict["audio"] = tfds.features.Tensor(
          shape=[None], dtype=tf.float32)
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict(features_dict),
        urls=["https://g.co/magenta/groove-dataset"],
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns one `SplitGenerator` per value of the "split" CSV column.

    Args:
      dl_manager: download manager used to fetch and extract the archive; the
        MIDI-only archive is used when the config does not include audio.

    Returns:
      A list of `tfds.core.SplitGenerator`s. The "train" split uses 10 shards
      and every other split uses 1.
    """
    # Download data.
    data_dir = os.path.join(
        dl_manager.download_and_extract(
            _DOWNLOAD_URL if self._builder_config.include_audio else
            _DOWNLOAD_URL_MIDI_ONLY),
        "groove")

    # Group the metadata rows by the split they are assigned to.
    rows = collections.defaultdict(list)
    with tf.io.gfile.GFile(os.path.join(data_dir, "info.csv")) as f:
      reader = csv.DictReader(f)
      for row in reader:
        rows[row["split"]].append(row)

    return [
        tfds.core.SplitGenerator(  # pylint: disable=g-complex-comprehension
            name=split,
            num_shards=10 if split == "train" else 1,
            gen_kwargs={"rows": split_rows, "data_dir": data_dir})
        for split, split_rows in rows.items()]

  def _generate_examples(self, rows, data_dir):
    """Yields one example per row, or one per bar window when splitting.

    Args:
      rows: iterable of dicts read from `info.csv` for a single split.
      data_dir: directory against which the MIDI/audio filenames in each row
        are resolved.

    Yields:
      Feature dicts matching `_info`. When the config sets `split_bars`, each
      window is yielded as its own dict whose id is the row id followed by
      ":NNN", the zero-padded index of the window's first bar.
    """
    split_bars = self._builder_config.split_bars
    for row in rows:
      split_genre = row["style"].split("/")
      with tf.io.gfile.GFile(
          os.path.join(data_dir, row["midi_filename"]), "rb") as midi_f:
        midi = midi_f.read()
      audio = None
      if self._builder_config.include_audio:
        if not row["audio_filename"]:
          # Skip examples with no audio.
          logging.warning("Skipping example with no audio: %s", row["id"])
          continue
        wav_path = os.path.join(data_dir, row["audio_filename"])
        audio = _load_wav(wav_path, self._builder_config.audio_rate)

      bpm = int(row["bpm"])
      example = {
          "id": row["id"],
          "drummer": row["drummer"],
          "type": row["beat_type"],
          "bpm": bpm,
          "time_signature": row["time_signature"],
          "style": {
              "primary": split_genre[0],
              "secondary": split_genre[1] if len(split_genre) == 2 else ""
          },
      }
      if not split_bars:
        # Yield full example.
        example["midi"] = midi
        if audio is not None:
          example["audio"] = audio
        yield example
        continue

      # Yield split examples.
      # NOTE(review): only the numerator of the time signature is used, so
      # e.g. a "6-8" bar is treated as six beats at `bpm` -- confirm this
      # matches the tempo convention of the recordings.
      beats_per_bar = int(row["time_signature"].split("-")[0])
      bar_duration = 60 / bpm * beats_per_bar
      audio_rate = self._builder_config.audio_rate

      pm = tfds.core.lazy_imports.pretty_midi.PrettyMIDI(io.BytesIO(midi))
      total_duration = pm.get_end_time()

      # Pad final bar if at least half filled.
      total_bars = int(round(total_duration / bar_duration))
      total_frames = int(total_bars * bar_duration * audio_rate)
      if audio is not None and len(audio) < total_frames:
        audio = np.pad(audio, [0, total_frames - len(audio)], "constant")

      for i in range(total_bars - split_bars + 1):
        time_range = [i * bar_duration, (i + split_bars) * bar_duration]

        # Build a fresh dict per window. Reusing and mutating one dict made
        # the ":NNN" suffixes pile up on the id ("x:000:001:...") and made
        # every yielded window alias the same object.
        split_example = dict(example)
        split_example["id"] = "%s:%03d" % (row["id"], i)

        # Split MIDI.
        pm_split = copy.deepcopy(pm)
        pm_split.adjust_times(time_range, [0, split_bars * bar_duration])
        pm_split.time_signature_changes = pm.time_signature_changes
        midi_split = io.BytesIO()
        pm_split.write(midi_split)
        split_example["midi"] = midi_split.getvalue()

        # Split audio.
        if audio is not None:
          split_example["audio"] = audio[
              int(time_range[0] * audio_rate):
              int(time_range[1] * audio_rate)]

        yield split_example


def _load_wav(path, sample_rate):
  """Reads a WAV file as mono float32 samples at the requested rate.

  Args:
    path: location of the WAV file, opened through `tf.io.gfile`.
    sample_rate: int, frame rate the decoded audio is converted to.

  Returns:
    float32 `np.ndarray` holding the integer samples divided by
    `2**(8 * sample_width)`.
  """
  with tf.io.gfile.GFile(path, "rb") as wav_file:
    segment = tfds.core.lazy_imports.pydub.AudioSegment.from_file(
        wav_file, format="wav")
    segment = segment.set_channels(1)
    segment = segment.set_frame_rate(sample_rate)
  int_samples = segment.get_array_of_samples()
  # Convert from int to float representation.
  full_scale = 2**(8 * segment.sample_width)
  audio = np.array(int_samples).astype(np.float32)
  audio /= full_scale
  return audio

67 changes: 67 additions & 0 deletions tensorflow_datasets/audio/groove_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Groove dataset module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_datasets import testing
from tensorflow_datasets.audio import groove


class GrooveFullTest(testing.DatasetBuilderTestCase):
  """Builder test for the unsplit config with 16 kHz audio."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["full-16000hz"]
  DL_EXTRACT_RESULT = ".."
  # Expected number of examples per split.
  SPLITS = {"train": 2, "test": 1}


class GrooveFullMidiOnlyTest(testing.DatasetBuilderTestCase):
  """Builder test for the unsplit, MIDI-only config."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["full-midionly"]
  DL_EXTRACT_RESULT = ".."
  # Expected number of examples per split.
  SPLITS = {"train": 3, "test": 1}


class Groove2BarTest(testing.DatasetBuilderTestCase):
  """Builder test for the 2-bar config with 16 kHz audio."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["2bar-16000hz"]
  DL_EXTRACT_RESULT = ".."
  # Expected number of examples per split; train is 3 + 2.
  SPLITS = {"train": 5, "test": 1}


class Groove2BarMidiOnlyTest(testing.DatasetBuilderTestCase):
  """Builder test for the 2-bar, MIDI-only config."""

  DATASET_CLASS = groove.Groove
  BUILDER_CONFIG_NAMES_TO_TEST = ["2bar-midionly"]
  DL_EXTRACT_RESULT = ".."
  # Expected number of examples per split; train is 3 + 2 + 1.
  SPLITS = {"train": 6, "test": 1}


if __name__ == "__main__":
  # Run the test cases defined above when executed as a script.
  testing.test_main()
5 changes: 5 additions & 0 deletions tensorflow_datasets/core/lazy_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def PIL_Image(cls): # pylint: disable=invalid-name
_try_import("PIL.TiffImagePlugin")
return _try_import("PIL.Image")

@utils.classproperty
@classmethod
def pretty_midi(cls):
  """Returns the result of `_try_import("pretty_midi")`."""
  module_name = "pretty_midi"
  return _try_import(module_name)

@utils.classproperty
@classmethod
def pyplot(cls):
Expand Down
1 change: 1 addition & 0 deletions tensorflow_datasets/core/lazy_imports_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class LazyImportsTest(testing.TestCase, parameterized.TestCase):
"matplotlib",
"mwparserfromhell",
"os",
"pretty_midi",
"pydub",
"pyplot",
"scipy",
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split
drummer3,drummer3/session1,drummer3/session1/23,jazz,120,beat,4-4,drummer3/session1/23_jazz_120_beat_4-4.mid,drummer3/session1/23_jazz_120_beat_4-4.wav,7.96875,train
drummer3,drummer3/session1,drummer3/session1/44,rock/hard,120,fill,4-4,drummer3/session1/44_rock_120_beat_4-4.mid,drummer3/session1/44_rock_120_beat_4-4.wav,3.986458,test
drummer7,drummer7/session1,drummer7/session1/3,rock,86,fill,3-4,drummer7/session1/3_rock_86_fill_4-4.mid,drummer7/session1/3_rock_86_fill_4-4.wav,5.696218,train
drummer7,drummer7/session2,drummer7/session2/116,jazz/fusion,96,beat,4-4,drummer7/session2/116_jazz-fusion_96_fill_4-4.mid,,4.41276,train
2 changes: 2 additions & 0 deletions tensorflow_datasets/url_checksums/groove.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip 3260318 651cbc524ffb891be1a3e46d89dc82a1cecb09a57c748c7b45b844c4841dcc1e
https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0.zip 5111599714 21559feb2f1c96ca53988fd4d7060b1f2afe1d854fb2a8dcea5ff95cf3cce7e9

0 comments on commit dbda817

Please sign in to comment.