diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..9ee5bcb4 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +v1/ linguist-detectable=false diff --git a/.gitignore b/.gitignore index 1b4a6f5a..3753ae38 100644 --- a/.gitignore +++ b/.gitignore @@ -101,4 +101,11 @@ ENV/ .mypy_cache/ # vscode -.vscode/* +/.vscode/ + +# Experiments +/data/ +/preprocessing/ +/exp/ +logs/ +BACKUP/ diff --git a/README.md b/README.md index e5d4358b..ff87318d 100755 --- a/README.md +++ b/README.md @@ -16,68 +16,84 @@ drums, guitar, piano and strings tracks. Sample results are available [here](https://salu133445.github.io/musegan/results). -## Papers +## BinaryMuseGAN -Hao-Wen Dong\*, Wen-Yi Hsiao\*, Li-Chia Yang and Yi-Hsuan Yang, -"**MuseGAN: Multi-track Sequential Generative Adversarial Networks for -Symbolic Music Generation and Accompaniment**," -in *AAAI Conference on Artificial Intelligence* (AAAI), 2018. -[[arxiv](http://arxiv.org/abs/1709.06298)] -[[slides](https://salu133445.github.io/musegan/pdf/musegan-aaai2018-slides.pdf)] - -Hao-Wen Dong\*, Wen-Yi Hsiao\*, Li-Chia Yang and Yi-Hsuan Yang, -"**MuseGAN: Demonstration of a Convolutional GAN Based Model for Generating -Multi-track Piano-rolls**," -in *ISMIR Late-Breaking and Demo Session*, 2017. -(non-peer reviewed two-page extended abstract) -[[paper](https://salu133445.github.io/musegan/pdf/musegan-ismir2017-lbd-paper.pdf)] -[[poster](https://salu133445.github.io/musegan/pdf/musegan-ismir2017-lbd-poster.pdf)] +[BinaryMuseGAN](https://salu133445.github.io/bmusegan/) is a follow-up project +of the [MuseGAN](https://salu133445.github.io/musegan/) project. -\* *These authors contributed equally to this work.* +In this project, we first investigate how the real-valued piano-rolls generated +by the generator may lead to difficulties in training the discriminator for +CNN-based models. To overcome the binarization issue, we propose to append to +the generator an additional refiner network, which try to refine the real-valued +predictions generated by the pretrained generator to binary-valued ones. The +proposed model is able to directly generate binary-valued piano-rolls at test +time. -## Usage +We trained the network with +[Lakh Pianoroll Dataset](https://salu133445.github.io/lakh-pianoroll-dataset/) +(LPD). We use the model to generate four-bar musical phrases consisting of eight +tracks: *Drums*, *Piano*, *Guitar*, *Bass*, *Ensemble*, *Reed*, *Synth Lead* and +*Synth Pad*. Audio samples are available +[here](https://salu133445.github.io/bmusegan/samples). -```python -import tensorflow as tf -from musegan.core import MuseGAN -from musegan.components import NowbarHybrid -from config import * +## Run the code -# Initialize a tensorflow session -with tf.Session() as sess: +### Configuration - # === Prerequisites === - # Step 1 - Initialize the training configuration - t_config = TrainingConfig +Modify `config.py` for configuration. - # Step 2 - Select the desired model - model = NowbarHybrid(NowBarHybridConfig) +- Quick setup - # Step 3 - Initialize the input data object - input_data = InputDataNowBarHybrid(model) + Change the values in the dictionary `SETUP` for a quick setup. Documentation + is provided right after each key. - # Step 4 - Load training data - path_train = 'train.npy' - input_data.add_data(path_train, key='train') +- More configuration options - # Step 5 - Initialize a museGAN object - musegan = MuseGAN(sess, t_config, model) + Four dictionaries `EXP_CONFIG`, `DATA_CONFIG`, `MODEL_CONFIG` and + `TRAIN_CONFIG` define experiment-, data-, model- and training-related + configuration variables, respectively. - # === Training === - musegan.train(input_data) + > The automatically-determined experiment name is based only on the values +defined in the dictionary `SETUP`, so remember to provide the experiment name +manually (so that you won't overwrite a trained model). - # === Load a Pretrained Model === - musegan.load(musegan.dir_ckpt) +### Run - # === Generate Samples === - path_test = 'train.npy' - input_data.add_data(path_test, key='test') - musegan.gen_test(input_data, is_eval=True) +```sh +python main.py ``` ## Training Data -- [tra_phr.npy](https://drive.google.com/uc?id=1-bQCO6ZxpIgdMM7zXhNJViovHjtBKXde&export=download) - (7.54 GB) contains 50,266 four-bar phrases. The shape is (50266, 384, 84, 5). -- [tra_bar.npy](https://drive.google.com/uc?id=1Xxj6WU82fcgY9UtBpXJGOspoUkMu58xC&export=download) - (4.79 GB) contains 127,734 bars. The shape is (127734, 96, 84, 5). +- Prepare your own data + + The array will be reshaped to (-1, `num_bar`, `num_timestep`, `num_pitch`, + `num_track`). These variables are defined in `config.py`. + +- Download our training data with this [script](training_data/download.sh) or + download it manually [here](https://salu133445.github.io/musegan/data). + +## Papers + +- Hao-Wen Dong and Yi-Hsuan Yang, + "Convolutional Generative Adversarial Networks with Binary Neurons for + Polyphonic Music Generation", + *arXiv preprint, arXiv:1804.09399*, 2018. + [[arxiv](https://arxiv.org/abs/1804.09399)] + +- Hao-Wen Dong\*, Wen-Yi Hsiao\*, Li-Chia Yang and Yi-Hsuan Yang, + "MuseGAN: Multi-track Sequential Generative Adversarial Networks for + Symbolic Music Generation and Accompaniment," + in *AAAI Conference on Artificial Intelligence* (AAAI), 2018. + [[arxiv](http://arxiv.org/abs/1709.06298)] + [[slides](https://salu133445.github.io/musegan/pdf/musegan-aaai2018-slides.pdf)] + +- Hao-Wen Dong\*, Wen-Yi Hsiao\*, Li-Chia Yang and Yi-Hsuan Yang, + "MuseGAN: Demonstration of a Convolutional GAN Based Model for Generating + Multi-track Piano-rolls," + in *ISMIR Late-Breaking and Demo Session*, 2017. + (non-peer reviewed two-page extended abstract) + [[paper](https://salu133445.github.io/musegan/pdf/musegan-ismir2017-lbd-paper.pdf)] + [[poster](https://salu133445.github.io/musegan/pdf/musegan-ismir2017-lbd-poster.pdf)] + +\* *These authors contributed equally to this work.* diff --git a/config.py b/config.py index 28199cf2..7c090317 100644 --- a/config.py +++ b/config.py @@ -1,170 +1,338 @@ -''' -Model Configuration -''' -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Define configuration variables in experiment, model and training levels. -import numpy as np -from shutil import copyfile +Quick Setup +=========== +Change the values in the dictionary `SETUP` for a quick setup. +Documentation is provided right after each key. + +Configuration +============= +More configuration options are providedin as a dictionary `CONFIG`. +`CONFIG['exp']`, `CONFIG['data']`, `CONFIG['model']`, `CONFIG['train']` and +`CONFIG['tensorflow']` define experiment-, data-, model-, training-, +TensorFlow-related configuration variables, respectively. + +Note that the automatically-determined experiment name is based only on the +values defined in the dictionary `SETUP`, so remember to provide the experiment +name manually if you have changed the configuration so that you won't overwrite +existing experiment directories. +""" import os -import SharedArray as sa +import shutil +import distutils.dir_util +import importlib +import numpy as np import tensorflow as tf -import glob - -print('[*] config...') - -# class Dataset: -TRACK_NAMES = ['bass', 'drums', 'guitar', 'piano', 'strings'] - -def get_colormap(): - colormap = np.array([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.], - [1., .5, 0.], - [0., .5, 1.]]) - return tf.constant(colormap, dtype=tf.float32, name='colormap') - -########################################################################### -# Training -########################################################################### - -class TrainingConfig: - is_eval = True - batch_size = 64 - epoch = 20 - iter_to_save = 100 - sample_size = 64 - print_batch = True - drum_filter = np.tile([1,0.3,0,0,0,0.3], 16) - scale_mask = [1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1.] - inter_pair = [(0,2), (0,3), (0,4), (2,3), (2,4), (3,4)] - track_names = TRACK_NAMES - track_dim = len(track_names) - eval_map = np.array([ - [1, 1, 1, 1, 1], # metric_is_empty_bar - [1, 1, 1, 1, 1], # metric_num_pitch_used - [1, 0, 1, 1, 1], # metric_too_short_note_ratio - [1, 0, 1, 1, 1], # metric_polyphonic_ratio - [1, 0, 1, 1, 1], # metric_in_scale - [0, 1, 0, 0, 0], # metric_drum_pattern - [1, 0, 1, 1, 1] # metric_num_chroma_used - ]) - - exp_name = 'exp' - gpu_num = '1' - - -########################################################################### -# Model Config -########################################################################### - -class ModelConfig: - output_w = 96 - output_h = 84 - lamda = 10 - batch_size = 64 - beta1 = 0.5 - beta2 = 0.9 - lr = 2e-4 - is_bn = True - colormap = get_colormap() - -# image -class MNISTConfig(ModelConfig): - output_w = 28 - output_h = 28 - z_dim = 74 - output_dim = 1 - -# RNN -class RNNConfig(ModelConfig): - track_names = ['All'] - track_dim = 1 - output_bar = 4 - z_inter_dim = 128 - output_dim = 5 - acc_idx = None - state_size = 128 - -# onebar -class OneBarHybridConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - acc_idx = None - z_inter_dim = 64 - z_intra_dim = 64 - output_dim = 1 - -class OneBarJammingConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - acc_idx = None - z_intra_dim = 128 - output_dim = 1 - -class OneBarComposerConfig(ModelConfig): - track_names = ['All'] - track_dim = 1 - acc_idx = None - z_inter_dim = 128 - output_dim = 5 - -# nowbar - -class NowBarHybridConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - acc_idx = 4 - z_inter_dim = 64 - z_intra_dim = 64 - output_dim = 1 - -class NowBarJammingConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - acc_idx = 4 - z_intra_dim = 128 - output_dim = 1 - -class NowBarComposerConfig(ModelConfig): - track_names = ['All'] - track_dim = 1 - acc_idx = 4 - z_inter_dim = 128 - output_dim = 5 - -# Temporal -class TemporalHybridConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - output_bar = 4 - z_inter_dim = 32 - z_intra_dim = 32 - acc_idx = None - output_dim = 1 - -class TemporalJammingConfig(ModelConfig): - track_names = TRACK_NAMES - track_dim = 5 - output_bar = 4 - z_intra_dim = 64 - output_dim = 1 - -class TemporalComposerConfig(ModelConfig): - track_names = ['All'] - track_dim = 1 - output_bar = 4 - z_inter_dim = 64 - acc_idx = None - output_dim = 5 - -class NowBarTemporalHybridConfig(ModelConfig): - track_names = TRACK_NAMES - acc_idx = 4 - track_dim = 5 - output_bar = 4 - z_inter_dim = 32 - z_intra_dim = 32 - acc_idx = 4 - output_dim = 1 + +# Quick setup +SETUP = { + 'model': 'musegan', + # {'musegan', 'bmusegan'} + # The model to use. Currently support MuseGAN and BinaryMuseGAN models. + + 'exp_name': None, + # The experiment name. Also the name of the folder that will be created + # in './exp/' and all the experiment-related files are saved in that + # folder. None to determine automatically. The automatically- + # determined experiment name is based only on the values defined in the + # dictionary `SETUP`, so remember to provide the experiment name manually + # (so that you won't overwrite a trained model). + + 'prefix': 'lastfm_alternative', + # Prefix for the experiment name. Useful when training with different + # training data to avoid replacing the previous experiment outputs. + + 'training_data': 'lastfm_alternative_8b_phrase', + # Filename of the training data. The training data can be loaded from a npy + # file in the hard disk or from the shared memory using SharedArray package. + # Note that the data will be reshaped to (-1, num_bar, num_timestep, + # num_pitch, num_track) and remember to set these variable to proper values, + # which are defined in `CONFIG['model']`. + + 'training_data_location': 'sa', + # Location of the training data. 'hd' to load from a npy file stored in the + # hard disk. 'sa' to load from shared array using SharedArray package. + + 'gpu': '0', + # The GPU index in os.environ['CUDA_VISIBLE_DEVICES'] to use. + + 'preset_g': 'hybrid', + # MuseGAN: {'composer', 'jamming', 'hybrid'} + # BinaryMuseGAN: {'proposed', 'proposed_small'} + # Use a preset network architecture for the generator or set to None and + # setup `CONFIG['model']['net_g']` to define the network architecture. + + 'preset_d': 'proposed', + # {'proposed', 'proposed_small', 'ablated', 'baseline', None} + # Use a preset network architecture for the discriminator or set to None + # and setup `CONFIG['model']['net_d']` to define the network architecture. + + 'pretrained_dir': None, + # The directory containing the pretrained model. None to retrain the + # model from scratch. + + 'verbose': True, + # True to print each batch details to stdout. False to print once an epoch. + + 'sample_along_training': True, + # True to generate samples along the training process. False for nothing. + + 'evaluate_along_training': True, + # True to run evaluation along the training process. False for nothing. + + # ------------------------- For BinaryMuseGAN only ------------------------- + 'two_stage_training': True, + # True to train the model in a two-stage training setting. False to + # train the model in an end-to-end manner. + + 'training_phase': 'first_stage', + # {'first_stage', 'second_stage'} + # The training phase in a two-stage training setting. Only effective + # when `two_stage_training` is True. + + 'first_stage_dir': None, + # The directory containing the pretrained first-stage model. None to + # determine automatically (assuming using default `exp_name`). Only + # effective when two_stage_training` is True and `training_phase` is + # 'second_stage'. + + 'joint_training': False, + # True to train the generator and the refiner jointly. Only effective + # when `two_stage_training` is True and `training_phase` is 'second_stage'. + + 'preset_r': 'proposed_bernoulli', + # {'proposed_round', 'proposed_bernoulli'} + # Use a preset network architecture for the refiner or set to None and + # setup `CONFIG['model']['net_r']` to define the network architecture. +} + +CONFIG = {} + +#=============================================================================== +#=========================== TensorFlow Configuration ========================== +#=============================================================================== +os.environ['CUDA_VISIBLE_DEVICES'] = SETUP['gpu'] +CONFIG['tensorflow'] = tf.ConfigProto() +CONFIG['tensorflow'].gpu_options.allow_growth = True + +#=============================================================================== +#========================== Experiment Configuration =========================== +#=============================================================================== +CONFIG['exp'] = { + 'model': None, + 'exp_name': None, + 'pretrained_dir': None, + 'two_stage_training': None, # For BinaryMuseGAN only + 'first_stage_dir': None, # For BinaryMuseGAN only +} + +for key in ('model', 'pretrained_dir'): + if CONFIG['exp'][key] is None: + CONFIG['exp'][key] = SETUP[key] + +if SETUP['model'] == 'musegan': + # Set default experiment name + if CONFIG['exp']['exp_name'] is None: + if SETUP['exp_name'] is not None: + CONFIG['exp']['exp_name'] = SETUP['exp_name'] + else: + CONFIG['exp']['exp_name'] = '_'.join( + (SETUP['prefix'], 'g', SETUP['preset_g'], 'd', + SETUP['preset_d'])) + +if SETUP['model'] == 'bmusegan': + if CONFIG['exp']['two_stage_training'] is None: + CONFIG['exp']['two_stage_training'] = SETUP['two_stage_training'] + # Set default experiment name + if CONFIG['exp']['exp_name'] is None: + if SETUP['exp_name'] is not None: + CONFIG['exp']['exp_name'] = SETUP['exp_name'] + elif not SETUP['two_stage_training']: + CONFIG['exp']['exp_name'] = '_'.join( + (SETUP['prefix'], 'end2end', 'g', SETUP['preset_g'], 'd', + SETUP['preset_d'], 'r', SETUP['preset_r'])) + elif SETUP['training_phase'] == 'first_stage': + CONFIG['exp']['exp_name'] = '_'.join( + (SETUP['prefix'], SETUP['training_phase'], 'g', + SETUP['preset_g'], 'd', SETUP['preset_d'])) + elif SETUP['training_phase'] == 'second_stage': + if SETUP['joint_training']: + CONFIG['exp']['exp_name'] = '_'.join( + (SETUP['prefix'], SETUP['training_phase'], 'joint', 'g', + SETUP['preset_g'], 'd', SETUP['preset_d'], 'r', + SETUP['preset_r'])) + else: + CONFIG['exp']['exp_name'] = '_'.join( + (SETUP['prefix'], SETUP['training_phase'], 'g', + SETUP['preset_g'], 'd', SETUP['preset_d'], 'r', + SETUP['preset_r'])) + # Set default first stage model directory + if CONFIG['exp']['first_stage_dir'] is None: + if SETUP['first_stage_dir'] is not None: + CONFIG['exp']['first_stage_dir'] = SETUP['first_stage_dir'] + else: + CONFIG['exp']['first_stage_dir'] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'exp', + '_'.join((SETUP['prefix'], 'first_stage', 'g', + SETUP['preset_g'], 'd', SETUP['preset_d'])), + 'checkpoints') + +#=============================================================================== +#============================= Data Configuration ============================== +#=============================================================================== +CONFIG['data'] = { + 'training_data': None, + 'training_data_location': None, +} + +for key in ('training_data', 'training_data_location'): + if CONFIG['data'][key] is None: + CONFIG['data'][key] = SETUP[key] + +#=============================================================================== +#=========================== Training Configuration ============================ +#=============================================================================== +CONFIG['train'] = { + 'num_epoch': 20, + 'verbose': None, + 'sample_along_training': None, + 'evaluate_along_training': None, + 'two_stage_training': None, # For BinaryMuseGAN only + 'training_phase': None, # For BinaryMuseGAN only + 'slope_annealing_rate': 1.1, # For BinaryMuseGAN only +} + +for key in ('verbose', 'sample_along_training', 'evaluate_along_training'): + if CONFIG['train'][key] is None: + CONFIG['train'][key] = SETUP[key] + +if SETUP['model'] == 'bmusegan' and CONFIG['train']['training_phase'] is None: + CONFIG['train']['training_phase'] = SETUP['training_phase'] + +#=============================================================================== +#============================= Model Configuration ============================= +#=============================================================================== +CONFIG['model'] = { + 'joint_training': None, # For BinaryMuseGAN only + + # Parameters + 'batch_size': 32, # Note: tf.layers.conv3d_transpose requires a fixed batch + # size in TensorFlow < 1.6 + 'gan': { + 'type': 'wgan-gp', # 'gan', 'wgan', 'wgan-gp' + 'clip_value': .01, + 'gp_coefficient': 10. + }, + 'optimizer': { + # Parameters for the Adam optimizers + 'lr': .002, + 'beta1': .5, + 'beta2': .9, + 'epsilon': 1e-8 + }, + + # Data + 'num_bar': 4, + 'num_beat': 4, + 'num_pitch': 84, + 'num_track': 8, + 'num_timestep': 96, + 'beat_resolution': 24, + 'lowest_pitch': 24, # MIDI note number of the lowest pitch in data tensors + + # Tracks + 'track_names': ( + 'Drums', 'Piano', 'Guitar', 'Bass', 'Ensemble', 'Reed', 'Synth Lead', + 'Synth Pad' + ), + 'programs': (0, 0, 24, 32, 48, 64, 80, 88), + 'is_drums': (True, False, False, False, False, False, False, False), + + # Network architectures (define them here if not using the presets) + 'net_g': None, + 'net_d': None, + 'net_r': None, # For BinaryMuseGAN only + + # Playback + 'pause_between_samples': 96, + 'tempo': 90., + + # Samples + 'num_sample': 16, + 'sample_grid': (2, 8), + + # Metrics + 'metric_map': np.array([ + # indices of tracks for the metrics to compute + [True] * 8, # empty bar rate + [True] * 8, # number of pitch used + [False] + [True] * 7, # qualified note rate + [False] + [True] * 7, # polyphonicity + [False] + [True] * 7, # in scale rate + [True] + [False] * 7, # in drum pattern rate + [False] + [True] * 7 # number of chroma used + ], dtype=bool), + 'tonal_distance_pairs': [(1, 2)], # pairs to compute the tonal distance + 'scale_mask': list(map(bool, [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])), + 'drum_filter': np.tile([1., .1, 0., 0., 0., .1], 16), + 'tonal_matrix_coefficient': (1., 1., .5), + + # Directories + 'checkpoint_dir': None, + 'sample_dir': None, + 'eval_dir': None, + 'log_dir': None, + 'src_dir': None, +} + +if SETUP['model'] == 'bmusegan' and CONFIG['model']['joint_training'] is None: + CONFIG['model']['joint_training'] = SETUP['joint_training'] + +# Import preset network architectures +if CONFIG['model']['net_g'] is None: + IMPORTED = importlib.import_module( + '.'.join(('musegan', SETUP['model'], 'presets', 'generator', + SETUP['preset_g']))) + CONFIG['model']['net_g'] = IMPORTED.NET_G + +if CONFIG['model']['net_d'] is None: + IMPORTED = importlib.import_module( + '.'.join(('musegan', SETUP['model'], 'presets', 'discriminator', + SETUP['preset_d']))) + CONFIG['model']['net_d'] = IMPORTED.NET_D + +if SETUP['model'] == 'bmusegan' and CONFIG['model']['net_r'] is None: + IMPORTED = importlib.import_module( + '.'.join(('musegan.bmusegan.presets', 'refiner', SETUP['preset_r']))) + CONFIG['model']['net_r'] = IMPORTED.NET_R + +# Set default directories +for kv_pair in (('checkpoint_dir', 'checkpoints'), ('sample_dir', 'samples'), + ('eval_dir', 'eval'), ('log_dir', 'logs'), ('src_dir', 'src')): + if CONFIG['model'][kv_pair[0]] is None: + CONFIG['model'][kv_pair[0]] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'exp', SETUP['model'], + CONFIG['exp']['exp_name'], kv_pair[1]) + +#=============================================================================== +#=================== Make directories & Backup source code ===================== +#=============================================================================== +# Make sure directories exist +for path in (CONFIG['model']['checkpoint_dir'], CONFIG['model']['sample_dir'], + CONFIG['model']['eval_dir'], CONFIG['model']['log_dir'], + CONFIG['model']['src_dir']): + if not os.path.exists(path): + os.makedirs(path) + +# Backup source code +for path in os.listdir(os.path.dirname(os.path.realpath(__file__))): + if os.path.isfile(path): + if path.endswith('.py'): + shutil.copyfile(os.path.basename(path), + os.path.join(CONFIG['model']['src_dir'], + os.path.basename(path))) + +distutils.dir_util.copy_tree( + os.path.join(os.path.dirname(os.path.realpath(__file__)), 'musegan'), + os.path.join(CONFIG['model']['src_dir'], 'musegan') +) diff --git a/main.py b/main.py index dd8d0067..6e48d41a 100644 --- a/main.py +++ b/main.py @@ -1,47 +1,116 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import scipy.misc +"""Train the model +""" +import importlib import numpy as np import tensorflow as tf -from pprint import pprint import SharedArray as sa +from config import CONFIG +MODELS = importlib.import_module( + '.'.join(('musegan', CONFIG['exp']['model'], 'models'))) -from musegan.core import * -from musegan.components import * -from input_data import * -from config import * +def load_data(): + """Load and return the training data.""" + print('[*] Loading data...') -#assign GPU + # Load data from SharedArray + if CONFIG['data']['training_data_location'] == 'sa': + x_train = sa.attach(CONFIG['data']['training_data']) + # Load data from hard disk + elif CONFIG['data']['training_data_location'] == 'hd': + x_train = np.load(CONFIG['data']['training_data']) -if __name__ == '__main__': + # Reshape data + x_train = x_train.reshape( + -1, CONFIG['model']['num_bar'], CONFIG['model']['num_timestep'], + CONFIG['model']['num_pitch'], CONFIG['model']['num_track']) + print('Training set size:', len(x_train)) + + return x_train + +def main(): + """Main function.""" + if CONFIG['exp']['model'] not in ('musegan', 'bmusegan'): + raise ValueError("Unrecognizable model name") + + print("Start experiment: {}".format(CONFIG['exp']['exp_name'])) + + # Load training data + x_train = load_data() + + # Open TensorFlow session + with tf.Session(config=CONFIG['tensorflow']) as sess: + + # ============================== MuseGAN =============================== + if CONFIG['exp']['model'] == 'musegan': + + # Create model + gan = MODELS.GAN(sess, CONFIG['model']) + + # Initialize all variables + gan.init_all() + + # Load pretrained model if given + if CONFIG['exp']['pretrained_dir'] is not None: + gan.load_latest(CONFIG['exp']['pretrained_dir']) - """ Create TensorFlow Session """ + # Train the model + gan.train(x_train, CONFIG['train']) - t_config = TrainingConfig + # =========================== BinaryMuseGAN ============================ + elif CONFIG['exp']['model'] == 'bmusegan': - os.environ['CUDA_VISIBLE_DEVICES'] = t_config.gpu_num - config = tf.ConfigProto() - config.gpu_options.allow_growth = True + # ------------------------ Two-stage model ------------------------- + if CONFIG['exp']['two_stage_training']: - with tf.Session(config=config) as sess: + # Create model + gan = MODELS.GAN(sess, CONFIG['model']) - path_x_train_phr = 'tra_X_phrase_all' # (50266, 384, 84, 5) + # Initialize all variables + gan.init_all() - # Temporal - # hybrid - t_config.exp_name = 'exps/temporal_hybrid' - model = TemporalHybrid(TemporalHybridConfig) - input_data = InputDataTemporalHybrid(model) - input_data.add_data_sa(path_x_train_phr, 'train') + # First stage training + if CONFIG['train']['training_phase'] == 'first_stage': - musegan = MuseGAN(sess, t_config, model) - musegan.train(input_data) + # Load pretrained model if given + if CONFIG['exp']['pretrained_dir'] is not None: + gan.load_latest(CONFIG['exp']['pretrained_dir']) - musegan.load(musegan.dir_ckpt) - musegan.gen_test(input_data, is_eval=True) + # Train the model + gan.train(x_train, CONFIG['train']) + # Second stage training + if CONFIG['train']['training_phase'] == 'two_stage': + # Load first-stage pretrained model + gan.load_latest(CONFIG['exp']['first_stage_dir']) + + refine_gan = MODELS.RefineGAN(sess, CONFIG['model'], gan) + + # Initialize all variables + refine_gan.init_all() + + # Load pretrained model if given + if CONFIG['exp']['pretrained_dir'] is not None: + refine_gan.load_latest(CONFIG['exp']['pretrained_dir']) + + # Train the model + refine_gan.train(x_train, CONFIG['train']) + + # ------------------------ End-to-end model ------------------------ + else: + # Create model + end2end_gan = MODELS.End2EndGAN(sess, CONFIG['model']) + + # Initialize all variables + end2end_gan.init_all() + + # Load pretrained model if given + if CONFIG['exp']['pretrained_dir'] is not None: + end2end_gan.load_latest(CONFIG['exp']['pretrained_dir']) + + # Train the model + end2end_gan.train(x_train, CONFIG['train']) + +if __name__ == '__main__': + main() diff --git a/musegan/__init__.py b/musegan/__init__.py index e69de29b..8b137891 100644 --- a/musegan/__init__.py +++ b/musegan/__init__.py @@ -0,0 +1 @@ + diff --git a/musegan/bmusegan/__init__.py b/musegan/bmusegan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/bmusegan/components.py b/musegan/bmusegan/components.py new file mode 100644 index 00000000..fbbf9b3a --- /dev/null +++ b/musegan/bmusegan/components.py @@ -0,0 +1,208 @@ +"""Classes that define the generator, the discriminator and the refiner. +""" +from collections import OrderedDict +import tensorflow as tf +from musegan.component import Component +from musegan.utils.neuralnet import NeuralNet + +class Generator(Component): + """Class that defines the generator.""" + def __init__(self, tensor_in, config, condition=None, name='Generator', + reuse=None): + super().__init__(tensor_in, condition) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the generator.""" + nets = OrderedDict() + + nets['shared'] = NeuralNet(self.tensor_in, config['net_g']['shared'], + name='shared') + + nets['pitch_time_private'] = [ + NeuralNet(nets['shared'].tensor_out, + config['net_g']['pitch_time_private'], + name='pt_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['time_pitch_private'] = [ + NeuralNet(nets['shared'].tensor_out, + config['net_g']['time_pitch_private'], + name='tp_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['merged_private'] = [ + NeuralNet(tf.concat([nets['pitch_time_private'][idx].tensor_out, + nets['time_pitch_private'][idx].tensor_out], + -1), + config['net_g']['merged_private'], + name='merged_'+str(idx)) + for idx in range(config['num_track']) + ] + + tensor_out = tf.concat([nn.tensor_out for nn in nets['merged_private']], + -1) + return tensor_out, nets + +class Discriminator(Component): + """Class that defines the discriminator.""" + def __init__(self, tensor_in, config, condition=None, name='Discriminator', + reuse=None): + super().__init__(tensor_in, condition) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the discriminator.""" + nets = OrderedDict() + + # main stream + nets['pitch_time_private'] = [ + NeuralNet(tf.expand_dims(self.tensor_in[..., idx], -1), + config['net_d']['pitch_time_private'], + name='pt_' + str(idx)) + for idx in range(config['num_track']) + ] + + nets['time_pitch_private'] = [ + NeuralNet(tf.expand_dims(self.tensor_in[..., idx], -1), + config['net_d']['time_pitch_private'], + name='tp_' + str(idx)) + for idx in range(config['num_track']) + ] + + nets['merged_private'] = [ + NeuralNet( + tf.concat([x.tensor_out, + nets['time_pitch_private'][idx].tensor_out], -1), + config['net_d']['merged_private'], name='merged_' + str(idx)) + for idx, x in enumerate(nets['pitch_time_private']) + ] + + nets['shared'] = NeuralNet( + tf.concat([nn.tensor_out for nn in nets['merged_private']], -1), + config['net_d']['shared'], name='shared' + ) + + # chroma stream + reshaped = tf.reshape( + self.tensor_in, (-1, config['num_bar'], config['num_beat'], + config['beat_resolution'], config['num_pitch']//12, + 12, config['num_track']) + ) + self.chroma = tf.reduce_sum(reshaped, axis=(3, 4)) + nets['chroma'] = NeuralNet(self.chroma, config['net_d']['chroma'], + name='chroma') + + # onset stream + padded = tf.pad(self.tensor_in[:, :, :-1, :, 1:], + [[0, 0], [0, 0], [1, 0], [0, 0], [0, 0]]) + self.onset = tf.concat([tf.expand_dims(self.tensor_in[..., 0], -1), + self.tensor_in[..., 1:] - padded], -1) + nets['onset'] = NeuralNet(self.onset, config['net_d']['onset'], + name='onset') + + if (config['net_d']['chroma'] is not None + or config['net_d']['onset'] is not None): + to_concat = [nets['shared'].tensor_out] + if config['net_d']['chroma'] is not None: + to_concat.append(nets['chroma'].tensor_out) + if config['net_d']['onset'] is not None: + to_concat.append(nets['onset'].tensor_out) + concated = tf.concat(to_concat, -1) + else: + concated = nets['shared'].tensor_out + + # merge streams + nets['merged'] = NeuralNet(concated, config['net_d']['merged'], + name='merged') + + return nets['merged'].tensor_out, nets + +class Refiner(Component): + """Class that defines the refiner.""" + def __init__(self, tensor_in, config, condition=None, slope_tensor=None, + name='Refiner', reuse=None): + super().__init__(tensor_in, condition, slope_tensor) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets, self.preactivated = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the refiner.""" + nets = OrderedDict() + + nets['private'] = [ + NeuralNet(tf.expand_dims(self.tensor_in[..., idx], -1), + config['net_r']['private'], + slope_tensor=self.slope_tensor, name='private'+str(idx)) + for idx in range(config['num_track']) + ] + + return (tf.concat([nn.tensor_out for nn in nets['private']], -1), nets, + tf.concat([nn.layers[-1].preactivated + for nn in nets['private']], -1)) + +class End2EndGenerator(Component): + """Class that defines the end-to-end generator.""" + def __init__(self, tensor_in, config, condition=None, slope_tensor=None, + name='End2EndGenerator', reuse=None): + super().__init__(tensor_in, condition, slope_tensor) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets, self.preactivated = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the end-to-end generator.""" + nets = OrderedDict() + + nets['shared'] = NeuralNet(self.tensor_in, config['net_g']['shared'], + name='shared') + + nets['pitch_time_private'] = [ + NeuralNet(nets['shared'].tensor_out, + config['net_g']['pitch_time_private'], + name='pt_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['time_pitch_private'] = [ + NeuralNet(nets['shared'].tensor_out, + config['net_g']['time_pitch_private'], + name='tp_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['merged_private'] = [ + NeuralNet(tf.concat([nets['pitch_time_private'][idx].tensor_out, + nets['time_pitch_private'][idx].tensor_out], + -1), + config['net_g']['merged_private'], + name='merged_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['refiner_private'] = [ + NeuralNet(nets['merged_private'][idx].tensor_out, + config['net_r']['private'], + slope_tensor=self.slope_tensor, + name='refiner_private'+str(idx)) + for idx in range(config['num_track']) + ] + + return (tf.concat([nn.tensor_out for nn in nets['private']], -1), nets, + tf.concat([nn.layers[-1].preactivated + for nn in nets['private']], -1)) diff --git a/musegan/bmusegan/models.py b/musegan/bmusegan/models.py new file mode 100644 index 00000000..2bac7589 --- /dev/null +++ b/musegan/bmusegan/models.py @@ -0,0 +1,617 @@ +"""Classes that define the GAN, RefinerGAN and End2EndGAN models. +""" +import os.path +import time +import numpy as np +import tensorflow as tf +from musegan.model import Model +from musegan.bmusegan.components import Discriminator, Generator, Refiner +from musegan.bmusegan.components import End2EndGenerator +from musegan.utils.metrics import Metrics + +class GAN(Model): + """Class that defines the first-stage (without refiner) model.""" + def __init__(self, sess, config, name='GAN', reuse=None): + super().__init__(sess, config, name) + + print('[*] Building GAN...') + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.build() + + def build(self): + """Build the model.""" + self.global_step = tf.Variable(0, trainable=False, name='global_step') + + # Create placeholders + self.z = tf.placeholder( + tf.float32, + (self.config['batch_size'], self.config['net_g']['z_dim']), 'z' + ) + data_shape = (self.config['batch_size'], self.config['num_bar'], + self.config['num_timestep'], self.config['num_pitch'], + self.config['num_track']) + self.x = tf.placeholder(tf.bool, data_shape, 'x') + self.x_ = tf.cast(self.x, tf.float32, 'x_') + + # Components + self.G = Generator(self.z, self.config, name='G') + self.test_round = self.G.tensor_out > 0.5 + self.test_bernoulli = self.G.tensor_out > tf.random_uniform(data_shape) + + self.D_fake = Discriminator(self.G.tensor_out, self.config, name='D') + self.D_real = Discriminator(self.x_, self.config, name='D', reuse=True) + self.components = (self.G, self.D_fake) + + # Losses + self.g_loss, self.d_loss = self.get_adversarial_loss(Discriminator) + + # Optimizers + with tf.variable_scope('Optimizer'): + self.g_optimizer = self.get_optimizer() + self.g_step = self.g_optimizer.minimize( + self.g_loss, self.global_step, self.G.vars) + + self.d_optimizer = self.get_optimizer() + self.d_step = self.d_optimizer.minimize( + self.d_loss, self.global_step, self.D_fake.vars) + + # Apply weight clipping + if self.config['gan']['type'] == 'wgan': + with tf.control_dependencies([self.d_step]): + self.d_step = tf.group( + *(tf.assign(var, tf.clip_by_value( + var, -self.config['gan']['clip_value'], + self.config['gan']['clip_value'])) + for var in self.D_fake.vars)) + + # Metrics + self.metrics = Metrics(self.config) + + # Saver + self.saver = tf.train.Saver() + + # Print and save model information + self.print_statistics() + self.save_statistics() + self.print_summary() + self.save_summary() + + def train(self, x_train, train_config): + """Train the model.""" + # Initialize sampler + self.z_sample = np.random.normal( + size=(self.config['batch_size'], self.config['net_g']['z_dim'])) + self.x_sample = x_train[np.random.choice( + len(x_train), self.config['batch_size'], False)] + feed_dict_sample = {self.x: self.x_sample, self.z: self.z_sample} + + # Save samples + self.save_samples('x_train', x_train, save_midi=True) + self.save_samples('x_sample', self.x_sample, save_midi=True) + + # Open log files and write headers + log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') + log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') + log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') + log_step.write('# epoch, step, negative_critic_loss\n') + log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') + log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') + + # Initialize counter + counter = 0 + num_batch = len(x_train) // self.config['batch_size'] + + # Start epoch iteration + print('{:=^80}'.format(' Training Start ')) + for epoch in range(train_config['num_epoch']): + + print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) + epoch_start_time = time.time() + + # Prepare batched training data + z_random_batch = np.random.normal( + size=(num_batch, self.config['batch_size'], + self.config['net_g']['z_dim']) + ) + x_random_batch = np.random.choice( + len(x_train), (num_batch, self.config['batch_size']), False + ) + + # Start batch iteration + for batch in range(num_batch): + + feed_dict_batch = {self.x: x_train[x_random_batch[batch]], + self.z: z_random_batch[batch]} + + if (counter < 25) or (counter % 500 == 0): + num_critics = 100 + else: + num_critics = 5 + + batch_start_time = time.time() + + # Update networks + for _ in range(num_critics): + _, d_loss = self.sess.run([self.d_step, self.d_loss], + feed_dict_batch) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + _, d_loss, g_loss = self.sess.run( + [self.g_step, self.d_loss, self.g_loss], feed_dict_batch + ) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + time_batch = time.time() - batch_start_time + + # Print iteration summary + if train_config['verbose']: + if batch < 1: + print("epoch | batch | time | - D_loss |" + " G_loss") + print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " + "{:14.6f}".format(epoch, batch, num_batch, time_batch, + -d_loss, g_loss)) + + log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( + epoch, batch, time_batch, -d_loss, g_loss + )) + + # run sampler + if train_config['sample_along_training']: + if counter%100 == 0 or (counter < 300 and counter%20 == 0): + self.run_sampler(self.G.tensor_out, feed_dict_sample, + False) + self.run_sampler(self.test_round, feed_dict_sample, + (counter > 500), postfix='test_round') + self.run_sampler(self.test_bernoulli, feed_dict_sample, + (counter > 500), + postfix='test_bernoulli') + + # run evaluation + if train_config['evaluate_along_training']: + if counter%10 == 0: + self.run_eval(self.test_round, feed_dict_sample, + postfix='test_round') + self.run_eval(self.test_bernoulli, feed_dict_sample, + postfix='test_bernoulli') + + counter += 1 + + # print epoch info + time_epoch = time.time() - epoch_start_time + + if not train_config['verbose']: + if epoch < 1: + print("epoch | time | - D_loss | G_loss") + print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( + epoch, time_epoch, -d_loss, g_loss)) + + log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( + epoch, time_epoch, -d_loss, g_loss + )) + + # save checkpoints + self.save() + + print('{:=^80}'.format(' Training End ')) + log_step.close() + log_batch.close() + log_epoch.close() + +class RefineGAN(Model): + """Class that defines the second-stage (with refiner) model.""" + def __init__(self, sess, config, pretrained, name='RefineGAN', reuse=None): + super().__init__(sess, config, name) + self.pretrained = pretrained + + print('[*] Building RefineGAN...') + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.build() + + def build(self): + """Build the model.""" + # Create global step variable + self.global_step = tf.Variable(0, trainable=False, name='global_step') + + # Get tensors from the pretrained model + self.z = self.pretrained.z + self.x = self.pretrained.x + self.x_ = self.pretrained.x_ + + # Slope tensor for applying slope annealing trick to stochastic neurons + self.slope_tensor = tf.Variable(1.0) + + # Components + self.G = Refiner(self.pretrained.G.tensor_out, self.config, + slope_tensor=self.slope_tensor, name='R') + self.D_real = self.pretrained.D_real + with tf.variable_scope(self.pretrained.scope, reuse=True): + self.D_fake = Discriminator(self.G.tensor_out, self.config, + name='D') + self.components = (self.pretrained.G, self.G, self.D_fake) + + # Losses + self.g_loss, self.d_loss = self.get_adversarial_loss( + Discriminator, self.pretrained.scope) + + # Optimizers + with tf.variable_scope('Optimizer'): + self.g_optimizer = self.get_optimizer() + if self.config['joint_training']: + self.g_step = self.g_optimizer.minimize( + self.g_loss, self.global_step, (self.G.vars + + self.pretrained.G.vars)) + else: + self.g_step = self.g_optimizer.minimize( + self.g_loss, self.global_step, self.G.vars) + self.d_optimizer = self.get_optimizer() + self.d_step = self.d_optimizer.minimize( + self.d_loss, self.global_step, self.D_fake.vars) + + # Apply weight clipping + if self.config['gan']['type'] == 'wgan': + with tf.control_dependencies([self.d_step]): + self.d_step = tf.group( + *(tf.assign(var, tf.clip_by_value( + var, -self.config['gan']['clip_value'], + self.config['gan']['clip_value'])) + for var in self.D_fake.vars)) + + # Metrics + self.metrics = Metrics(self.config) + + # Saver + self.saver = tf.train.Saver() + + # Print and save model information + self.print_statistics() + self.save_statistics() + self.print_summary() + self.save_summary() + + def train(self, x_train, train_config): + """Train the model.""" + # Initialize sampler + self.z_sample = np.random.normal( + size=(self.config['batch_size'], self.config['net_g']['z_dim'])) + self.x_sample = x_train[np.random.choice( + len(x_train), self.config['batch_size'], False)] + feed_dict_sample = {self.x: self.x_sample, self.z: self.z_sample} + + # Save samples + self.save_samples('x_train', x_train, save_midi=True) + self.save_samples('x_sample', self.x_sample, save_midi=True) + + pretrained_samples = self.sess.run(self.pretrained.G.tensor_out, + feed_dict_sample) + self.save_samples('pretrained', pretrained_samples) + + for threshold in [0.1, 0.3, 0.5, 0.7, 0.9]: + pretrained_threshold = (pretrained_samples > threshold) + self.save_samples('pretrained_threshold_{}'.format(threshold), + pretrained_threshold, save_midi=True) + + for idx in range(5): + pretrained_bernoulli = np.ceil( + pretrained_samples + - np.random.uniform(size=pretrained_samples.shape)) + self.save_samples('pretrained_bernoulli_{}'.format(idx), + pretrained_bernoulli, save_midi=True) + + # Open log files and write headers + log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') + log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') + log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') + log_step.write('# epoch, step, negative_critic_loss\n') + log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') + log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') + + # Define slope annealing op + if train_config['slope_annealing_rate'] != 1.: + slope_annealing_op = tf.assign( + self.slope_tensor, + self.slope_tensor * train_config['slope_annealing_rate']) + + # Initialize counter + counter = 0 + num_batch = len(x_train) // self.config['batch_size'] + + # Start epoch iteration + print('{:=^80}'.format(' Training Start ')) + for epoch in range(train_config['num_epoch']): + + print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) + epoch_start_time = time.time() + + # Prepare batched training data + z_random_batch = np.random.normal( + size=(num_batch, self.config['batch_size'], + self.config['net_g']['z_dim']) + ) + x_random_batch = np.random.choice( + len(x_train), (num_batch, self.config['batch_size']), False) + + # Start batch iteration + for batch in range(num_batch): + + feed_dict_batch = {self.x: x_train[x_random_batch[batch]], + self.z: z_random_batch[batch]} + + if counter % 500 == 0: # (counter < 25) + num_critics = 100 + else: + num_critics = 5 + + batch_start_time = time.time() + + # Update networks + for _ in range(num_critics): + _, d_loss = self.sess.run([self.d_step, self.d_loss], + feed_dict_batch) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + _, d_loss, g_loss = self.sess.run( + [self.g_step, self.d_loss, self.g_loss], feed_dict_batch + ) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + time_batch = time.time() - batch_start_time + + # Print iteration summary + if train_config['verbose']: + if batch < 1: + print("epoch | batch | time | - D_loss |" + " G_loss") + print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " + "{:14.6f}".format(epoch, batch, num_batch, time_batch, + -d_loss, g_loss)) + + log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( + epoch, batch, time_batch, -d_loss, g_loss + )) + + # run sampler + if train_config['sample_along_training']: + if counter%100 == 0 or (counter < 300 and counter%20 == 0): + self.run_sampler(self.G.tensor_out, feed_dict_sample, + (counter > 500)) + self.run_sampler(self.G.preactivated, feed_dict_sample, + False, postfix='preactivated') + + # run evaluation + if train_config['evaluate_along_training']: + if counter%10 == 0: + self.run_eval(self.G.tensor_out, feed_dict_sample) + + counter += 1 + + # print epoch info + time_epoch = time.time() - epoch_start_time + + if not train_config['verbose']: + if epoch < 1: + print("epoch | time | - D_loss | G_loss") + print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( + epoch, time_epoch, -d_loss, g_loss)) + + log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( + epoch, time_epoch, -d_loss, g_loss + )) + + # save checkpoints + self.save() + + if train_config['slope_annealing_rate'] != 1.: + self.sess.run(slope_annealing_op) + + print('{:=^80}'.format(' Training End ')) + log_step.close() + log_batch.close() + log_epoch.close() + +class End2EndGAN(Model): + """Class that defines the end-to-end model.""" + def __init__(self, sess, config, name='End2EndGAN', reuse=None): + super().__init__(sess, config, name) + + print('[*] Building End2EndGAN...') + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.build() + + def build(self): + """Build the model.""" + self.global_step = tf.Variable(0, trainable=False, name='global_step') + + # Create placeholders + self.z = tf.placeholder( + tf.float32, + (self.config['batch_size'], self.config['net_g']['z_dim']), 'z' + ) + data_shape = (self.config['batch_size'], self.config['num_bar'], + self.config['num_timestep'], self.config['num_pitch'], + self.config['num_track']) + self.x = tf.placeholder(tf.bool, data_shape, 'x') + self.x_ = tf.cast(self.x, tf.float32, 'x_') + + # Slope tensor for applying slope annealing trick to stochastic neurons + self.slope_tensor = tf.Variable(1.0) + + # Components + self.G = End2EndGenerator(self.z, self.config, + slope_tensor=self.slope_tensor, name='G') + self.D_fake = Discriminator(self.G.tensor_out, self.config, name='D') + self.D_real = Discriminator(self.x_, self.config, name='D', reuse=True) + self.components = (self.G, self.D_fake) + + # Losses + self.g_loss, self.d_loss = self.get_adversarial_loss(Discriminator) + + # Optimizers + with tf.variable_scope('Optimizer'): + self.g_optimizer = self.get_optimizer() + self.g_step = self.g_optimizer.minimize( + self.g_loss, self.global_step, self.G.vars) + + self.d_optimizer = self.get_optimizer() + self.d_step = self.d_optimizer.minimize( + self.d_loss, self.global_step, self.D_fake.vars) + + # Apply weight clipping + if self.config['gan']['type'] == 'wgan': + with tf.control_dependencies([self.d_step]): + self.d_step = tf.group( + *(tf.assign(var, tf.clip_by_value( + var, -self.config['gan']['clip_value'], + self.config['gan']['clip_value'])) + for var in self.D_fake.vars)) + + # Metrics + self.metrics = Metrics(self.config) + + # Saver + self.saver = tf.train.Saver() + + # Print and save model information + self.print_statistics() + self.save_statistics() + self.print_summary() + self.save_summary() + + def train(self, x_train, train_config): + """Train the model.""" + # Initialize sampler + self.z_sample = np.random.normal( + size=(self.config['batch_size'], self.config['net_g']['z_dim'])) + self.x_sample = x_train[np.random.choice( + len(x_train), self.config['batch_size'], False)] + feed_dict_sample = {self.x: self.x_sample, self.z: self.z_sample} + + # Save samples + self.save_samples('x_train', x_train, save_midi=True) + self.save_samples('x_sample', self.x_sample, save_midi=True) + + # Open log files and write headers + log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') + log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') + log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') + log_step.write('# epoch, step, negative_critic_loss\n') + log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') + log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') + + # Define slope annealing op + if train_config['slope_annealing_rate'] != 1.: + slope_annealing_op = tf.assign( + self.slope_tensor, + self.slope_tensor * train_config['slope_annealing_rate']) + + # Initialize counter + counter = 0 + num_batch = len(x_train) // self.config['batch_size'] + + # Start epoch iteration + print('{:=^80}'.format(' Training Start ')) + for epoch in range(train_config['num_epoch']): + + print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) + epoch_start_time = time.time() + + # Prepare batched training data + z_random_batch = np.random.normal( + size=(num_batch, self.config['batch_size'], + self.config['net_g']['z_dim']) + ) + x_random_batch = np.random.choice( + len(x_train), (num_batch, self.config['batch_size']), False) + + # Start batch iteration + for batch in range(num_batch): + + feed_dict_batch = {self.x: x_train[x_random_batch[batch]], + self.z: z_random_batch[batch]} + + if (counter < 25) or (counter % 500 == 0): + num_critics = 100 + else: + num_critics = 5 + + batch_start_time = time.time() + + # Update networks + for _ in range(num_critics): + _, d_loss = self.sess.run([self.d_step, self.d_loss], + feed_dict_batch) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + _, d_loss, g_loss = self.sess.run( + [self.g_step, self.d_loss, self.g_loss], feed_dict_batch + ) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + time_batch = time.time() - batch_start_time + + # Print iteration summary + if train_config['verbose']: + if batch < 1: + print("epoch | batch | time | - D_loss |" + " G_loss") + print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " + "{:14.6f}".format(epoch, batch, num_batch, time_batch, + -d_loss, g_loss)) + + log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( + epoch, batch, time_batch, -d_loss, g_loss + )) + + # run sampler + if train_config['sample_along_training']: + if counter%100 == 0 or (counter < 300 and counter%20 == 0): + self.run_sampler(self.G.tensor_out, feed_dict_sample, + (counter > 500)) + self.run_sampler(self.G.preactivated, feed_dict_sample, + False, postfix='preactivated') + + # run evaluation + if train_config['evaluate_along_training']: + if counter%10 == 0: + self.run_eval(self.G.tensor_out, feed_dict_sample) + + counter += 1 + + # print epoch info + time_epoch = time.time() - epoch_start_time + + if not train_config['verbose']: + if epoch < 1: + print("epoch | time | - D_loss | G_loss") + print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( + epoch, time_epoch, -d_loss, g_loss)) + + log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( + epoch, time_epoch, -d_loss, g_loss + )) + + # save checkpoints + self.save() + + if train_config['slope_annealing_rate'] != 1.: + self.sess.run(slope_annealing_op) + + print('{:=^80}'.format(' Training End ')) + log_step.close() + log_batch.close() + log_epoch.close() diff --git a/musegan/bmusegan/presets/__init__.py b/musegan/bmusegan/presets/__init__.py new file mode 100644 index 00000000..e90ccdf9 --- /dev/null +++ b/musegan/bmusegan/presets/__init__.py @@ -0,0 +1,2 @@ +"""Presets of network architectures for model components +""" diff --git a/musegan/bmusegan/presets/discriminator/__init__.py b/musegan/bmusegan/presets/discriminator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/bmusegan/presets/discriminator/ablated.py b/musegan/bmusegan/presets/discriminator/ablated.py new file mode 100644 index 00000000..e4ab7c28 --- /dev/null +++ b/musegan/bmusegan/presets/discriminator/ablated.py @@ -0,0 +1,33 @@ +"""Network architecture of the ablated (without onset and chroma streams) +discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (64, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (64, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 7) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = None + +NET_D['chroma'] = None + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/bmusegan/presets/discriminator/baseline.py b/musegan/bmusegan/presets/discriminator/baseline.py new file mode 100644 index 00000000..e4d67cb0 --- /dev/null +++ b/musegan/bmusegan/presets/discriminator/baseline.py @@ -0,0 +1,27 @@ +"""Network architecture of the baseline discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = None + +NET_D['time_pitch_private'] = None + +NET_D['merged_private'] = None + +NET_D['shared'] = None + +NET_D['onset'] = None + +NET_D['chroma'] = None + +NET_D['merged'] = [ + ('conv3d', (128, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (128, (1, 1, 3), (1, 1, 2)), None, 'lrelu'), # 1 (4, 96, 3) + ('conv3d', (256, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 2 (4, 16, 3) + ('conv3d', (256, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 4, 3) + ('conv3d', (512, (1, 1, 3), (1, 1, 3)), None, 'lrelu'), # 4 (4, 4, 1) + ('conv3d', (512, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 5 (4, 1, 1) + ('conv3d', (1024, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 6 (3, 1, 1) + ('reshape', (3*1024)), + ('dense', 1), +] diff --git a/musegan/bmusegan/presets/discriminator/proposed.py b/musegan/bmusegan/presets/discriminator/proposed.py new file mode 100644 index 00000000..cfb32b41 --- /dev/null +++ b/musegan/bmusegan/presets/discriminator/proposed.py @@ -0,0 +1,40 @@ +"""Network architecture of the proposed discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (64, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (64, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 7) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = [ + ('sum', (3), True), # 0 (4, 96, 1) + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 2 (4, 4, 1) + ('conv3d', (128, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 1, 1) +] + +NET_D['chroma'] = [ + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 4, 1) + ('conv3d', (128, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/bmusegan/presets/discriminator/proposed_small.py b/musegan/bmusegan/presets/discriminator/proposed_small.py new file mode 100644 index 00000000..ba7e76f4 --- /dev/null +++ b/musegan/bmusegan/presets/discriminator/proposed_small.py @@ -0,0 +1,40 @@ +"""Network architecture of the proposed discriminator (with less filters) +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (16, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (16, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (32, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 84) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = [ + ('sum', (3), True), # 0 (4, 96, 1) + ('conv3d', (16, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 1) + ('conv3d', (32, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 2 (4, 4, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 1, 1) +] + +NET_D['chroma'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 4, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/bmusegan/presets/generator/__init__.py b/musegan/bmusegan/presets/generator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/bmusegan/presets/generator/proposed.py b/musegan/bmusegan/presets/generator/proposed.py new file mode 100644 index 00000000..1236f8dc --- /dev/null +++ b/musegan/bmusegan/presets/generator/proposed.py @@ -0,0 +1,29 @@ +"""Network architecture of the proposed generator. +""" +NET_G = {} + +NET_G['z_dim'] = 128 + +NET_G['shared'] = [ + ('dense', (3*512), 'bn', 'relu'), # 0 + ('reshape', (3, 1, 1, 512)), # 1 (3, 1, 1) + ('transconv3d', (256, (2, 1, 1), (1, 1, 1)), 'bn', 'relu'), # 2 (4, 1, 1) + ('transconv3d', (128, (1, 4, 1), (1, 4, 1)), 'bn', 'relu'), # 3 (4, 4, 1) + ('transconv3d', (128, (1, 1, 3), (1, 1, 3)), 'bn', 'relu'), # 4 (4, 4, 3) + ('transconv3d', (64, (1, 4, 1), (1, 4, 1)), 'bn', 'relu'), # 5 (4, 16, 3) + ('transconv3d', (64, (1, 1, 3), (1, 1, 2)), 'bn', 'relu'), # 6 (4, 16, 7) +] + +NET_G['pitch_time_private'] = [ + ('transconv3d', (64, (1, 1, 12), (1, 1, 12)), 'bn', 'relu'),# 0 (4, 16, 84) + ('transconv3d', (32, (1, 6, 1), (1, 6, 1)), 'bn', 'relu'), # 1 (4, 96, 84) +] + +NET_G['time_pitch_private'] = [ + ('transconv3d', (64, (1, 6, 1), (1, 6, 1)), 'bn', 'relu'), # 0 (4, 96, 7) + ('transconv3d', (32, (1, 1, 12), (1, 1, 12)), 'bn', 'relu'),# 1 (4, 96, 84) +] + +NET_G['merged_private'] = [ + ('transconv3d', (1, (1, 1, 1), (1, 1, 1)), 'bn', 'sigmoid'),# 0 (4, 96, 84) +] diff --git a/musegan/bmusegan/presets/generator/proposed_small.py b/musegan/bmusegan/presets/generator/proposed_small.py new file mode 100644 index 00000000..e9813a16 --- /dev/null +++ b/musegan/bmusegan/presets/generator/proposed_small.py @@ -0,0 +1,29 @@ +"""Network architecture of the proposed generator. +""" +NET_G = {} + +NET_G['z_dim'] = 128 + +NET_G['shared'] = [ + ('dense', (3*256), 'bn', 'relu'), # 0 + ('reshape', (3, 1, 1, 256)), # 1 (3, 1, 1) + ('transconv3d', (256, (2, 1, 1), (1, 1, 1)), 'bn', 'relu'), # 2 (4, 1, 1) + ('transconv3d', (128, (1, 4, 1), (1, 4, 1)), 'bn', 'relu'), # 3 (4, 4, 1) + ('transconv3d', (128, (1, 1, 3), (1, 1, 3)), 'bn', 'relu'), # 4 (4, 4, 3) + ('transconv3d', (64, (1, 4, 1), (1, 4, 1)), 'bn', 'relu'), # 5 (4, 16, 3) + ('transconv3d', (64, (1, 1, 3), (1, 1, 2)), 'bn', 'relu'), # 6 (4, 16, 7) +] + +NET_G['pitch_time_private'] = [ + ('transconv3d', (64, (1, 1, 12), (1, 1, 12)), 'bn', 'relu'),# 0 (4, 16, 84) + ('transconv3d', (32, (1, 6, 1), (1, 6, 1)), 'bn', 'relu'), # 1 (4, 96, 84) +] + +NET_G['time_pitch_private'] = [ + ('transconv3d', (64, (1, 6, 1), (1, 6, 1)), 'bn', 'relu'), # 0 (4, 96, 7) + ('transconv3d', (32, (1, 1, 12), (1, 1, 12)), 'bn', 'relu'),# 1 (4, 96, 84) +] + +NET_G['merged_private'] = [ + ('transconv3d', (1, (1, 1, 1), (1, 1, 1)), 'bn', 'sigmoid'),# 0 (4, 96, 84) +] diff --git a/musegan/bmusegan/presets/refiner/__init__.py b/musegan/bmusegan/presets/refiner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/bmusegan/presets/refiner/proposed_bernoulli.py b/musegan/bmusegan/presets/refiner/proposed_bernoulli.py new file mode 100644 index 00000000..10010a01 --- /dev/null +++ b/musegan/bmusegan/presets/refiner/proposed_bernoulli.py @@ -0,0 +1,15 @@ +"""Refiner built with residual blocks and bernoulli activation +""" +NET_R = {} + +NET_R['private'] = [ + ('identity', None, None, None), + ('identity', None, 'bn', 'relu'), + ('conv3d', (64, (1, 3, 12), (1, 1, 1), 'SAME'), 'bn', 'relu'), + ('conv3d', (1, (1, 3, 12), (1, 1, 1), 'SAME'), None, None), + ('identity', None, None, None, ('add', 0)), + ('identity', None, 'bn', 'relu'), + ('conv3d', (64, (1, 3, 12), (1, 1, 1), 'SAME'), 'bn', 'relu'), + ('conv3d', (1, (1, 3, 12), (1, 1, 1), 'SAME'), None, None), + ('identity', None, None, 'bernoulli', ('add', 4)), +] diff --git a/musegan/bmusegan/presets/refiner/proposed_round.py b/musegan/bmusegan/presets/refiner/proposed_round.py new file mode 100644 index 00000000..5bb6f7a8 --- /dev/null +++ b/musegan/bmusegan/presets/refiner/proposed_round.py @@ -0,0 +1,15 @@ +"""Refiner built with residual blocks and binary round activation +""" +NET_R = {} + +NET_R['private'] = [ + ('identity', None, None, None), + ('identity', None, 'bn', 'relu'), + ('conv3d', (64, (1, 3, 12), (1, 1, 1), 'SAME'), 'bn', 'relu'), + ('conv3d', (1, (1, 3, 12), (1, 1, 1), 'SAME'), None, None), + ('identity', None, None, None, ('add', 0)), + ('identity', None, 'bn', 'relu'), + ('conv3d', (64, (1, 3, 12), (1, 1, 1), 'SAME'), 'bn', 'relu'), + ('conv3d', (1, (1, 3, 12), (1, 1, 1), 'SAME'), None, None), + ('identity', None, None, 'round', ('add', 4)), +] diff --git a/musegan/component.py b/musegan/component.py new file mode 100644 index 00000000..3611b91c --- /dev/null +++ b/musegan/component.py @@ -0,0 +1,64 @@ +"""Base class for the components. +""" +from collections import OrderedDict +import tensorflow as tf +from musegan.utils.neuralnet import NeuralNet + +class Component(object): + """Base class for components.""" + def __init__(self, tensor_in, condition, slope_tensor=None): + if not isinstance(tensor_in, (tf.Tensor, list, dict)): + raise TypeError("`tensor_in` must be of tf.Tensor type or a list " + "(or dict) of tf.Tensor objects") + if isinstance(tensor_in, list): + for tensor in tensor_in: + if not isinstance(tensor, tf.Tensor): + raise TypeError("`tensor_in` must be of tf.Tensor type or " + "a list (or dict) of tf.Tensor objects") + if isinstance(tensor_in, dict): + for key in tensor_in: + if not isinstance(tensor_in[key], tf.Tensor): + raise TypeError("`tensor_in` must be of tf.Tensor type or " + "a list (or dict) of tf.Tensor objects") + + self.tensor_in = tensor_in + self.condition = condition + self.slope_tensor = slope_tensor + + self.scope = None + self.tensor_out = tensor_in + self.nets = OrderedDict() + self.vars = None + + def __repr__(self): + if isinstance(self.tensor_in, tf.Tensor): + input_shape = self.tensor_in.get_shape() + else: + input_shape = ', '.join([ + '{}: {}'.format(key, self.tensor_in[key].get_shape()) + for key in self.tensor_in]) + return "Component({}, input_shape={}, output_shape={})".format( + self.scope.name, input_shape, str(self.tensor_out.get_shape())) + + def get_summary(self): + """Return the summary string.""" + cleansed_nets = [] + for net in self.nets.values(): + if isinstance(net, NeuralNet): + if net.scope is not None: + cleansed_nets.append(net) + if isinstance(net, list): + if net[0].scope is not None: + cleansed_nets.append(net[0]) + + if isinstance(self.tensor_in, tf.Tensor): + input_strs = ["{:50}{}".format('Input', self.tensor_in.get_shape())] + else: + input_strs = ["{:50}{}".format('Input - ' + key, + self.tensor_in[key].get_shape()) + for key in self.tensor_in] + + return '\n'.join( + ["{:-^80}".format(' ' + self.scope.name + ' ')] + input_strs + + ['-' * 80 + '\n' + x.get_summary() for x in cleansed_nets] + ) diff --git a/musegan/model.py b/musegan/model.py new file mode 100644 index 00000000..672c734f --- /dev/null +++ b/musegan/model.py @@ -0,0 +1,198 @@ +"""Base class for models. +""" +import os.path +import numpy as np +import tensorflow as tf +from musegan.utils import midi_io +from musegan.utils import image_io + +class Model(object): + """Base class for models.""" + def __init__(self, sess, config, name='model'): + self.sess = sess + self.name = name + self.config = config + + self.scope = None + self.global_step = None + self.x_ = None + self.G = None + self.D_real = None + self.D_fake = None + self.components = [] + self.metrics = None + self.saver = None + + def init_all(self): + """Initialize all variables in the scope.""" + print('[*] Initializing variables...') + tf.variables_initializer(tf.global_variables(self.scope.name)).run() + + def get_adversarial_loss(self, discriminator, scope_to_reuse=None): + """Return the adversarial losses for the generator and the + discriminator.""" + if self.config['gan']['type'] == 'gan': + adv_loss_d = tf.losses.sigmoid_cross_entropy( + tf.ones_like(self.D_real.tensor_out), + self.D_real.tensor_out) + adv_loss_g = tf.losses.sigmoid_cross_entropy( + tf.zeros_like(self.D_fake.tensor_out), + self.D_fake.tensor_out) + + if (self.config['gan']['type'] == 'wgan' + or self.config['gan']['type'] == 'wgan-gp'): + adv_loss_d = (tf.reduce_mean(self.D_fake.tensor_out) + - tf.reduce_mean(self.D_real.tensor_out)) + adv_loss_g = -tf.reduce_mean(self.D_fake.tensor_out) + + if self.config['gan']['type'] == 'wgan-gp': + eps = tf.random_uniform( + [tf.shape(self.x_)[0], 1, 1, 1, 1], 0.0, 1.0) + inter = eps * self.x_ + (1. - eps) * self.G.tensor_out + if scope_to_reuse is None: + D_inter = discriminator(inter, self.config, name='D', + reuse=True) + else: + with tf.variable_scope(scope_to_reuse, reuse=True): + D_inter = discriminator(inter, self.config, name='D', + reuse=True) + gradient = tf.gradients(D_inter.tensor_out, inter)[0] + slopes = tf.sqrt(1e-8 + tf.reduce_sum( + tf.square(gradient), + tf.range(1, len(gradient.get_shape())))) + gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0)) + adv_loss_d += (self.config['gan']['gp_coefficient'] + * gradient_penalty) + + return adv_loss_g, adv_loss_d + + def get_optimizer(self): + """Return a Adam optimizer.""" + return tf.train.AdamOptimizer( + self.config['optimizer']['lr'], + self.config['optimizer']['beta1'], + self.config['optimizer']['beta2'], + self.config['optimizer']['epsilon']) + + def get_statistics(self): + """Return model statistics (number of paramaters for each component).""" + def get_num_parameter(var_list): + """Given the variable list, return the total number of parameters. + """ + return int(np.sum([np.product([x.value for x in var.get_shape()]) + for var in var_list])) + num_par = get_num_parameter(tf.trainable_variables( + self.scope.name)) + num_par_g = get_num_parameter(self.G.vars) + num_par_d = get_num_parameter(self.D_fake.vars) + return ("Number of parameters: {}\nNumber of parameters in G: {}\n" + "Number of parameters in D: {}".format(num_par, num_par_g, + num_par_d)) + + def get_summary(self): + """Return model summary.""" + return '\n'.join( + ["{:-^80}".format(' < ' + self.scope.name + ' > ')] + + [(x.get_summary() + '\n' + '-' * 80) for x in self.components]) + + def get_global_step_str(self): + """Return the global step as a string.""" + return str(tf.train.global_step(self.sess, self.global_step)) + + def print_statistics(self): + """Print model statistics (number of paramaters for each component).""" + print("{:=^80}".format(' Model Statistics ')) + print(self.get_statistics()) + + def print_summary(self): + """Print model summary.""" + print("{:=^80}".format(' Model Summary ')) + print(self.get_summary()) + + def save_statistics(self, filepath=None): + """Save model statistics to file. Default to save to the log directory + given as a global variable.""" + if filepath is None: + filepath = os.path.join(self.config['log_dir'], + 'model_statistics.txt') + with open(filepath, 'w') as f: + f.write(self.get_statistics()) + + def save_summary(self, filepath=None): + """Save model summary to file. Default to save to the log directory + given as a global variable.""" + if filepath is None: + filepath = os.path.join(self.config['log_dir'], 'model_summary.txt') + with open(filepath, 'w') as f: + f.write(self.get_summary()) + + def save(self, filepath=None): + """Save the model to a checkpoint file. Default to save to the log + directory given as a global variable.""" + if filepath is None: + filepath = os.path.join(self.config['checkpoint_dir'], + self.name + '.model') + print('[*] Saving checkpoint...') + self.saver.save(self.sess, filepath, self.global_step) + + def load(self, filepath): + """Load the model from the latest checkpoint in a directory.""" + print('[*] Loading checkpoint...') + self.saver.restore(self.sess, filepath) + + def load_latest(self, checkpoint_dir=None): + """Load the model from the latest checkpoint in a directory.""" + if checkpoint_dir is None: + checkpoint_dir = self.config['checkpoint_dir'] + print('[*] Loading checkpoint...') + checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) + if checkpoint_path is None: + raise ValueError("Checkpoint not found") + self.saver.restore(self.sess, checkpoint_path) + + def save_samples(self, filename, samples, save_midi=False, shape=None, + postfix=None): + """Save samples to an image file (and a MIDI file).""" + if shape is None: + shape = self.config['sample_grid'] + if len(samples) > self.config['num_sample']: + samples = samples[:self.config['num_sample']] + if postfix is None: + imagepath = os.path.join(self.config['sample_dir'], + '{}.png'.format(filename)) + else: + imagepath = os.path.join(self.config['sample_dir'], + '{}_{}.png'.format(filename, postfix)) + image_io.save_image(imagepath, samples, shape) + if save_midi: + binarized = (samples > 0) + midipath = os.path.join(self.config['sample_dir'], + '{}.mid'.format(filename)) + midi_io.save_midi(midipath, binarized, self.config) + + def run_sampler(self, targets, feed_dict, save_midi=False, postfix=None): + """Run the target operation with feed_dict and save the samples.""" + if not isinstance(targets, list): + targets = [targets] + results = self.sess.run(targets, feed_dict) + results = [result[:self.config['num_sample']] for result in results] + samples = np.stack(results, 1).reshape((-1,) + results[0].shape[1:]) + shape = [self.config['sample_grid'][0], + self.config['sample_grid'][1] * len(results)] + if postfix is None: + filename = self.get_global_step_str() + else: + filename = self.get_global_step_str() + '_' + postfix + self.save_samples(filename, samples, save_midi, shape) + + def run_eval(self, target, feed_dict, postfix=None): + """Run evaluation.""" + result = self.sess.run(target, feed_dict) + binarized = (result > 0) + if postfix is None: + filename = self.get_global_step_str() + else: + filename = self.get_global_step_str() + '_' + postfix + reshaped = binarized.reshape((-1,) + binarized.shape[2:]) + mat_path = os.path.join(self.config['eval_dir'], filename+'.npy') + _ = self.metrics.eval(reshaped, mat_path=mat_path) diff --git a/musegan/musegan/__init__.py b/musegan/musegan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/musegan/components.py b/musegan/musegan/components.py new file mode 100644 index 00000000..7aac85e4 --- /dev/null +++ b/musegan/musegan/components.py @@ -0,0 +1,218 @@ +"""Classes that define the generator and the discriminator. +""" +from collections import OrderedDict +import tensorflow as tf +from musegan.component import Component +from musegan.utils.neuralnet import NeuralNet + +class Generator(Component): + """Class that defines the generator.""" + def __init__(self, tensor_in, config, condition=None, name='Generator', + reuse=None): + super().__init__(tensor_in, condition) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the generator.""" + nets = OrderedDict() + + # Tile shared latent vector along time axis + if 'shared' in self.tensor_in: + tiled_shared = tf.reshape( + tf.tile(self.tensor_in['shared'], (1, 4)), + (-1, 4, self.tensor_in['shared'].get_shape()[1]) + ) + + # Define shared temporal generator + if 'temporal_shared' in self.tensor_in: + nets['temporal_shared'] = NeuralNet( + self.tensor_in['temporal_shared'], + config['net_g']['temporal_shared'], name='temporal_shared' + ) + + # Shared bar generator mode + if config['net_g']['bar_generator_type'] == 'shared': + if ('private' in self.tensor_in + or 'temporal_private' in self.tensor_in): + raise ValueError("Private latent vectors received for a shared" + "bar generator") + + # Get the final input for the bar generator + z_input = tf.concat([tiled_shared, + nets['temporal_shared'].tensor_out], -1) + + nets['bar_main'] = NeuralNet(z_input, config['net_g']['bar_main'], + name='bar_main') + + nets['bar_pitch_time'] = NeuralNet( + nets['bar_main'].tensor_out, config['net_g']['bar_pitch_time'], + name='bar_pitch_time' + ) + + nets['bar_time_pitch'] = NeuralNet( + nets['bar_main'].tensor_out, config['net_g']['bar_time_pitch'], + name='bar_time_pitch' + ) + + if config['net_g']['bar_merged'][-1][1][0] is None: + config['net_g']['bar_merged'][-1][1][0] = config['num_track'] + + nets['bar_merged'] = NeuralNet( + tf.concat([nets['bar_pitch_time'].tensor_out, + nets['bar_time_pitch'].tensor_out], -1), + config['net_g']['bar_merged'], name='bar_merged' + ) + + tensor_out = nets['bar_merged'].tensor_out + + # Private bar generator mode + elif config['net_g']['bar_generator_type'] == 'private': + # Tile private latent vector along time axis + if 'private' in self.tensor_in: + tiled_private = [ + tf.reshape( + tf.tile(self.tensor_in['private'][..., idx], (1, 4)), + (-1, 4, self.tensor_in['private'].get_shape()[1]) + ) + for idx in range(config['num_track']) + ] + + # Define private temporal generator + if 'temporal_private' in self.tensor_in: + nets['temporal_private'] = [ + NeuralNet(self.tensor_in['temporal_private'][..., idx], + config['net_g']['temporal_private'], + name='temporal_private_'+str(idx)) + for idx in range(config['num_track']) + ] + + # Get the final input for each bar generator + z_input = [] + for idx in range(config['num_track']): + to_concat = [] + if config['net_g']['z_dim_shared'] > 0: + to_concat.append(tiled_shared) + if config['net_g']['z_dim_private'] > 0: + to_concat.append(tiled_private[idx]) + if config['net_g']['z_dim_temporal_shared'] > 0: + to_concat.append(nets['temporal_shared'].tensor_out) + if config['net_g']['z_dim_temporal_private'] > 0: + to_concat.append(nets['temporal_private'][idx].tensor_out) + z_input.append(tf.concat(to_concat, -1)) + + # Bar generators + nets['bar_main'] = [ + NeuralNet(z_input[idx], config['net_g']['bar_main'], + name='bar_main_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['bar_pitch_time'] = [ + NeuralNet(nets['bar_main'][idx].tensor_out, + config['net_g']['bar_pitch_time'], + name='bar_pitch_time_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['bar_time_pitch'] = [ + NeuralNet(nets['bar_main'][idx].tensor_out, + config['net_g']['bar_time_pitch'], + name='bar_time_pitch_'+str(idx)) + for idx in range(config['num_track']) + ] + + nets['bar_merged'] = [ + NeuralNet( + tf.concat([nets['bar_pitch_time'][idx].tensor_out, + nets['bar_time_pitch'][idx].tensor_out], -1), + config['net_g']['bar_merged'], name='bar_merged_'+str(idx) + ) + for idx in range(config['num_track']) + ] + + tensor_out = tf.concat( + [l.tensor_out for l in nets['bar_merged']], -1) + + return tensor_out, nets + +class Discriminator(Component): + """Class that defines the discriminator.""" + def __init__(self, tensor_in, config, condition=None, name='Discriminator', + reuse=None): + super().__init__(tensor_in, condition) + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.tensor_out, self.nets = self.build(config) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + + def build(self, config): + """Build the discriminator.""" + nets = OrderedDict() + + # Main stream + nets['pitch_time_private'] = [ + NeuralNet(tf.expand_dims(self.tensor_in[..., idx], -1), + config['net_d']['pitch_time_private'], + name='pt_' + str(idx)) + for idx in range(config['num_track']) + ] + + nets['time_pitch_private'] = [ + NeuralNet(tf.expand_dims(self.tensor_in[..., idx], -1), + config['net_d']['time_pitch_private'], + name='tp_' + str(idx)) + for idx in range(config['num_track']) + ] + + nets['merged_private'] = [ + NeuralNet( + tf.concat([x.tensor_out, + nets['time_pitch_private'][idx].tensor_out], -1), + config['net_d']['merged_private'], name='merged_' + str(idx)) + for idx, x in enumerate(nets['pitch_time_private']) + ] + + nets['shared'] = NeuralNet( + tf.concat([l.tensor_out for l in nets['merged_private']], -1), + config['net_d']['shared'], name='shared' + ) + + # Chroma stream + reshaped = tf.reshape( + self.tensor_in, (-1, config['num_bar'], config['num_beat'], + config['beat_resolution'], config['num_pitch']//12, + 12, config['num_track']) + ) + self.chroma = tf.reduce_sum(reshaped, axis=(3, 4)) + nets['chroma'] = NeuralNet(self.chroma, config['net_d']['chroma'], + name='chroma') + + # Onset stream + padded = tf.pad(self.tensor_in[:, :, :-1, :, 1:], + [[0, 0], [0, 0], [1, 0], [0, 0], [0, 0]]) + self.onset = tf.concat([tf.expand_dims(self.tensor_in[..., 0], -1), + self.tensor_in[..., 1:] - padded], -1) + nets['onset'] = NeuralNet(self.onset, config['net_d']['onset'], + name='onset') + + if (config['net_d']['chroma'] is not None + or config['net_d']['onset'] is not None): + to_concat = [nets['shared'].tensor_out] + if config['net_d']['chroma'] is not None: + to_concat.append(nets['chroma'].tensor_out) + if config['net_d']['onset'] is not None: + to_concat.append(nets['onset'].tensor_out) + concated = tf.concat(to_concat, -1) + else: + concated = nets['shared'].tensor_out + + # Merge streams + nets['merged'] = NeuralNet(concated, config['net_d']['merged'], + name='merged') + + return nets['merged'].tensor_out, nets diff --git a/musegan/musegan/models.py b/musegan/musegan/models.py new file mode 100644 index 00000000..c92aad49 --- /dev/null +++ b/musegan/musegan/models.py @@ -0,0 +1,227 @@ +"""Class that defines the GAN model. +""" +import os.path +import time +import numpy as np +import tensorflow as tf +from musegan.model import Model +from musegan.musegan.components import Discriminator, Generator +from musegan.utils.metrics import Metrics + +class GAN(Model): + """Class that defines the first-stage (without refiner) model.""" + def __init__(self, sess, config, name='GAN', reuse=None): + super().__init__(sess, config, name) + + print('[*] Building GAN...') + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.build() + + def build(self): + """Build the model.""" + self.global_step = tf.Variable(0, trainable=False, name='global_step') + + # Create placeholders + self.z = {} + if self.config['net_g']['z_dim_shared'] > 0: + self.z['shared'] = tf.placeholder( + tf.float32, (self.config['batch_size'], + self.config['net_g']['z_dim_shared']), 'z_shared' + ) + if self.config['net_g']['z_dim_private'] > 0: + self.z['private'] = tf.placeholder( + tf.float32, (self.config['batch_size'], + self.config['net_g']['z_dim_private'], + self.config['num_track']), 'z_private' + ) + if self.config['net_g']['z_dim_temporal_shared'] > 0: + self.z['temporal_shared'] = tf.placeholder( + tf.float32, (self.config['batch_size'], + self.config['net_g']['z_dim_temporal_shared']), + 'z_temporal_shared' + ) + if self.config['net_g']['z_dim_temporal_private'] > 0: + self.z['temporal_private'] = tf.placeholder( + tf.float32, (self.config['batch_size'], + self.config['net_g']['z_dim_temporal_private'], + self.config['num_track']), 'z_temporal_private' + ) + + data_shape = (self.config['batch_size'], self.config['num_bar'], + self.config['num_timestep'], self.config['num_pitch'], + self.config['num_track']) + self.x = tf.placeholder(tf.bool, data_shape, 'x') + self.x_ = tf.cast(self.x, tf.float32, 'x_') + + # Components + self.G = Generator(self.z, self.config, name='G') + self.test_round = self.G.tensor_out > 0.5 + self.test_bernoulli = self.G.tensor_out > tf.random_uniform(data_shape) + + self.D_fake = Discriminator(self.G.tensor_out, self.config, name='D') + self.D_real = Discriminator(self.x_, self.config, name='D', reuse=True) + self.components = (self.G, self.D_fake) + + # Losses + self.g_loss, self.d_loss = self.get_adversarial_loss(Discriminator) + + # Optimizers + with tf.variable_scope('Optimizer'): + self.g_optimizer = self.get_optimizer() + self.g_step = self.g_optimizer.minimize( + self.g_loss, self.global_step, self.G.vars) + + self.d_optimizer = self.get_optimizer() + self.d_step = self.d_optimizer.minimize( + self.d_loss, self.global_step, self.D_fake.vars) + + # Apply weight clipping + if self.config['gan']['type'] == 'wgan': + with tf.control_dependencies([self.d_step]): + self.d_step = tf.group( + *(tf.assign(var, tf.clip_by_value( + var, -self.config['gan']['clip_value'], + self.config['gan']['clip_value'])) + for var in self.D_fake.vars)) + + # Metrics + self.metrics = Metrics(self.config) + + # Saver + self.saver = tf.train.Saver() + + # Print and save model information + self.print_statistics() + self.save_statistics() + self.print_summary() + self.save_summary() + + def train(self, x_train, train_config): + """Train the model.""" + # Initialize sampler + self.x_sample = x_train[np.random.choice( + len(x_train), self.config['batch_size'], False)] + feed_dict_sample = {self.x: self.x_sample} + + self.z_sample = {} + for key in self.z: + self.z_sample[key] = np.random.normal(size=self.z[key].get_shape()) + feed_dict_sample[self.z[key]] = self.z_sample[key] + + # Save samples + self.save_samples('x_train', x_train, save_midi=True) + self.save_samples('x_sample', self.x_sample, save_midi=True) + + # Open log files and write headers + log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') + log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') + log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') + log_step.write('# epoch, step, negative_critic_loss\n') + log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') + log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') + + # Initialize counter + counter = 0 + num_batch = len(x_train) // self.config['batch_size'] + + # Start epoch iteration + print('{:=^80}'.format(' Training Start ')) + for epoch in range(train_config['num_epoch']): + + print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) + epoch_start_time = time.time() + + # Prepare batched training data + z_random_batch = {} + for key in self.z: + z_random_batch[key] = np.random.normal( + size=([num_batch] + self.z[key].get_shape().as_list())) + x_random_batch = np.random.choice( + len(x_train), (num_batch, self.config['batch_size']), False) + + # Start batch iteration + for batch in range(num_batch): + + feed_dict_batch = {self.x: x_train[x_random_batch[batch]]} + for key in self.z: + feed_dict_batch[self.z[key]] = z_random_batch[key][batch] + + if (counter < 25) or (counter % 500 == 0): + num_critics = 100 + else: + num_critics = 5 + + batch_start_time = time.time() + + # Update networks + for _ in range(num_critics): + _, d_loss = self.sess.run([self.d_step, self.d_loss], + feed_dict_batch) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + _, d_loss, g_loss = self.sess.run( + [self.g_step, self.d_loss, self.g_loss], feed_dict_batch + ) + log_step.write("{}, {:14.6f}\n".format( + self.get_global_step_str(), -d_loss + )) + + time_batch = time.time() - batch_start_time + + # Print iteration summary + if train_config['verbose']: + if batch < 1: + print("epoch | batch | time | - D_loss |" + " G_loss") + print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " + "{:14.6f}".format(epoch, batch, num_batch, time_batch, + -d_loss, g_loss)) + + log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( + epoch, batch, time_batch, -d_loss, g_loss + )) + + # run sampler + if train_config['sample_along_training']: + if counter%100 == 0 or (counter < 300 and counter%20 == 0): + self.run_sampler(self.G.tensor_out, feed_dict_sample, + False) + self.run_sampler(self.test_round, feed_dict_sample, + (counter > 500), postfix='test_round') + self.run_sampler(self.test_bernoulli, feed_dict_sample, + (counter > 500), + postfix='test_bernoulli') + + # run evaluation + if train_config['evaluate_along_training']: + if counter%10 == 0: + self.run_eval(self.test_round, feed_dict_sample, + postfix='test_round') + self.run_eval(self.test_bernoulli, feed_dict_sample, + postfix='test_bernoulli') + + counter += 1 + + # print epoch info + time_epoch = time.time() - epoch_start_time + + if not train_config['verbose']: + if epoch < 1: + print("epoch | time | - D_loss | G_loss") + print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( + epoch, time_epoch, -d_loss, g_loss)) + + log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( + epoch, time_epoch, -d_loss, g_loss + )) + + # save checkpoints + self.save() + + print('{:=^80}'.format(' Training End ')) + log_step.close() + log_batch.close() + log_epoch.close() diff --git a/musegan/musegan/presets/__init__.py b/musegan/musegan/presets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/musegan/presets/discriminator/ablated.py b/musegan/musegan/presets/discriminator/ablated.py new file mode 100644 index 00000000..e4ab7c28 --- /dev/null +++ b/musegan/musegan/presets/discriminator/ablated.py @@ -0,0 +1,33 @@ +"""Network architecture of the ablated (without onset and chroma streams) +discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (64, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (64, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 7) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = None + +NET_D['chroma'] = None + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/musegan/presets/discriminator/baseline.py b/musegan/musegan/presets/discriminator/baseline.py new file mode 100644 index 00000000..e4d67cb0 --- /dev/null +++ b/musegan/musegan/presets/discriminator/baseline.py @@ -0,0 +1,27 @@ +"""Network architecture of the baseline discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = None + +NET_D['time_pitch_private'] = None + +NET_D['merged_private'] = None + +NET_D['shared'] = None + +NET_D['onset'] = None + +NET_D['chroma'] = None + +NET_D['merged'] = [ + ('conv3d', (128, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (128, (1, 1, 3), (1, 1, 2)), None, 'lrelu'), # 1 (4, 96, 3) + ('conv3d', (256, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 2 (4, 16, 3) + ('conv3d', (256, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 4, 3) + ('conv3d', (512, (1, 1, 3), (1, 1, 3)), None, 'lrelu'), # 4 (4, 4, 1) + ('conv3d', (512, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 5 (4, 1, 1) + ('conv3d', (1024, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 6 (3, 1, 1) + ('reshape', (3*1024)), + ('dense', 1), +] diff --git a/musegan/musegan/presets/discriminator/proposed.py b/musegan/musegan/presets/discriminator/proposed.py new file mode 100644 index 00000000..cfb32b41 --- /dev/null +++ b/musegan/musegan/presets/discriminator/proposed.py @@ -0,0 +1,40 @@ +"""Network architecture of the proposed discriminator +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (64, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (64, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 7) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = [ + ('sum', (3), True), # 0 (4, 96, 1) + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 2 (4, 4, 1) + ('conv3d', (128, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 1, 1) +] + +NET_D['chroma'] = [ + ('conv3d', (64, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 4, 1) + ('conv3d', (128, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/musegan/presets/discriminator/proposed_small.py b/musegan/musegan/presets/discriminator/proposed_small.py new file mode 100644 index 00000000..ba7e76f4 --- /dev/null +++ b/musegan/musegan/presets/discriminator/proposed_small.py @@ -0,0 +1,40 @@ +"""Network architecture of the proposed discriminator (with less filters) +""" +NET_D = {} + +NET_D['pitch_time_private'] = [ + ('conv3d', (16, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 96, 7) + ('conv3d', (32, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['time_pitch_private'] = [ + ('conv3d', (16, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 0 (4, 16, 84) + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 1 (4, 16, 7) +] + +NET_D['merged_private'] = [ + ('conv3d', (32, (1, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (4, 16, 84) +] + +NET_D['shared'] = [ + ('conv3d', (128, (1, 4, 3), (1, 4, 2)), None, 'lrelu'), # 0 (4, 4, 3) + ('conv3d', (256, (1, 4, 3), (1, 4, 3)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['onset'] = [ + ('sum', (3), True), # 0 (4, 96, 1) + ('conv3d', (16, (1, 6, 1), (1, 6, 1)), None, 'lrelu'), # 1 (4, 16, 1) + ('conv3d', (32, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 2 (4, 4, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 3 (4, 1, 1) +] + +NET_D['chroma'] = [ + ('conv3d', (32, (1, 1, 12), (1, 1, 12)), None, 'lrelu'), # 0 (4, 4, 1) + ('conv3d', (64, (1, 4, 1), (1, 4, 1)), None, 'lrelu'), # 1 (4, 1, 1) +] + +NET_D['merged'] = [ + ('conv3d', (512, (2, 1, 1), (1, 1, 1)), None, 'lrelu'), # 0 (3, 1, 1) + ('reshape', (3*512)), + ('dense', 1), +] diff --git a/musegan/musegan/presets/generator/__init__.py b/musegan/musegan/presets/generator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/musegan/musegan/presets/generator/composer.py b/musegan/musegan/presets/generator/composer.py new file mode 100644 index 00000000..66c38aa3 --- /dev/null +++ b/musegan/musegan/presets/generator/composer.py @@ -0,0 +1,48 @@ +"""Network architecture for the generator of the composer model. +""" +NET_G = {} + +# Input latent sizes (NOTE: use 0 instead of None) +NET_G['z_dim_shared'] = 64 +NET_G['z_dim_private'] = 0 +NET_G['z_dim_temporal_shared'] = 64 +NET_G['z_dim_temporal_private'] = 0 +NET_G['z_dim'] = (NET_G['z_dim_shared'] + NET_G['z_dim_private'] + + NET_G['z_dim_temporal_shared'] + + NET_G['z_dim_temporal_private']) + +# Temporal generators +NET_G['temporal_shared'] = [ + ('dense', (3*256), 'bn', 'lrelu'), + ('reshape', (3, 1, 1, 256)), # 1 (3, 1, 1) + ('transconv3d', (NET_G['z_dim_shared'], # 2 (4, 1, 1) + (2, 1, 1), (1, 1, 1)), 'bn', 'lrelu'), + ('reshape', (4, NET_G['z_dim_shared'])), +] + +NET_G['temporal_private'] = None + +# Bar generator +NET_G['bar_generator_type'] = 'shared' + +NET_G['bar_main'] = [ + ('reshape', (4, 1, 1, NET_G['z_dim'])), + ('transconv3d', (1024, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'),# 1 (4, 4, 1) + ('transconv3d', (512, (1, 1, 3), (1, 1, 3)), 'bn', 'lrelu'), # 2 (4, 4, 3) + ('transconv3d', (256, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'), # 3 (4, 16, 3) + ('transconv3d', (128, (1, 1, 3), (1, 1, 2)), 'bn', 'lrelu'), # 4 (4, 16, 7) +] + +NET_G['bar_pitch_time'] = [ + ('transconv3d', (64, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 0 (4, 16, 84) + ('transconv3d', (32, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 1 (4, 96, 84) +] + +NET_G['bar_time_pitch'] = [ + ('transconv3d', (64, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 0 (4, 96, 7) + ('transconv3d', (32, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 1 (4, 96, 84) +] + +NET_G['bar_merged'] = [ + ('transconv3d', [None, (1, 1, 1), (1, 1, 1)], 'bn', 'sigmoid'), +] diff --git a/musegan/musegan/presets/generator/hybrid.py b/musegan/musegan/presets/generator/hybrid.py new file mode 100644 index 00000000..781d5839 --- /dev/null +++ b/musegan/musegan/presets/generator/hybrid.py @@ -0,0 +1,54 @@ +"""Network architecture for the generator of the hybrid model. +""" +NET_G = {} + +# Input latent sizes (NOTE: use 0 instead of None) +NET_G['z_dim_shared'] = 32 +NET_G['z_dim_private'] = 32 +NET_G['z_dim_temporal_shared'] = 32 +NET_G['z_dim_temporal_private'] = 32 +NET_G['z_dim'] = (NET_G['z_dim_shared'] + NET_G['z_dim_private'] + + NET_G['z_dim_temporal_shared'] + + NET_G['z_dim_temporal_private']) + +# Temporal generators +NET_G['temporal_shared'] = [ + ('dense', (3*256), 'bn', 'lrelu'), + ('reshape', (3, 1, 1, 256)), # 1 (3, 1, 1) + ('transconv3d', (NET_G['z_dim_shared'], # 2 (4, 1, 1) + (2, 1, 1), (1, 1, 1)), 'bn', 'lrelu'), + ('reshape', (4, NET_G['z_dim_shared'])), +] + +NET_G['temporal_private'] = [ + ('dense', (3*256), 'bn', 'lrelu'), + ('reshape', (3, 1, 1, 256)), # 1 (3, 1, 1) + ('transconv3d', (NET_G['z_dim_private'], # 2 (4, 1, 1) + (2, 1, 1), (1, 1, 1)), 'bn', 'lrelu'), + ('reshape', (4, NET_G['z_dim_private'])), +] + +# Bar generator +NET_G['bar_generator_type'] = 'private' + +NET_G['bar_main'] = [ + ('reshape', (4, 1, 1, NET_G['z_dim'])), + ('transconv3d', (512, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'), # 1 (4, 4, 1) + ('transconv3d', (256, (1, 1, 3), (1, 1, 3)), 'bn', 'lrelu'), # 2 (4, 4, 3) + ('transconv3d', (128, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'), # 3 (4, 16, 3) + ('transconv3d', (64, (1, 1, 3), (1, 1, 2)), 'bn', 'lrelu'), # 4 (4, 16, 7) +] + +NET_G['bar_pitch_time'] = [ + ('transconv3d', (32, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 0 (4, 16, 84) + ('transconv3d', (16, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 1 (4, 96, 84) +] + +NET_G['bar_time_pitch'] = [ + ('transconv3d', (32, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 0 (4, 96, 7) + ('transconv3d', (16, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 1 (4, 96, 84) +] + +NET_G['bar_merged'] = [ + ('transconv3d', (1, (1, 1, 1), (1, 1, 1)), 'bn', 'sigmoid'), +] diff --git a/musegan/musegan/presets/generator/jamming.py b/musegan/musegan/presets/generator/jamming.py new file mode 100644 index 00000000..19753fe6 --- /dev/null +++ b/musegan/musegan/presets/generator/jamming.py @@ -0,0 +1,48 @@ +"""Network architecture for the generator of the jamming model. +""" +NET_G = {} + +# Input latent sizes (NOTE: use 0 instead of None) +NET_G['z_dim_shared'] = 0 +NET_G['z_dim_private'] = 64 +NET_G['z_dim_temporal_shared'] = 0 +NET_G['z_dim_temporal_private'] = 64 +NET_G['z_dim'] = (NET_G['z_dim_shared'] + NET_G['z_dim_private'] + + NET_G['z_dim_temporal_shared'] + + NET_G['z_dim_temporal_private']) + +# Temporal generators +NET_G['temporal_shared'] = None + +NET_G['temporal_private'] = [ + ('dense', (3*256), 'bn', 'lrelu'), + ('reshape', (3, 1, 1, 256)), # 1 (3, 1, 1) + ('transconv3d', (NET_G['z_dim_private'], # 2 (4, 1, 1) + (2, 1, 1), (1, 1, 1)), 'bn', 'lrelu'), + ('reshape', (4, NET_G['z_dim_private'])), +] + +# Bar generator +NET_G['bar_generator_type'] = 'private' + +NET_G['bar_main'] = [ + ('reshape', (4, 1, 1, NET_G['z_dim'])), + ('transconv3d', (512, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'), # 1 (4, 4, 1) + ('transconv3d', (256, (1, 1, 3), (1, 1, 3)), 'bn', 'lrelu'), # 2 (4, 4, 3) + ('transconv3d', (128, (1, 4, 1), (1, 4, 1)), 'bn', 'lrelu'), # 3 (4, 16, 3) + ('transconv3d', (64, (1, 1, 3), (1, 1, 2)), 'bn', 'lrelu'), # 4 (4, 16, 7) +] + +NET_G['bar_pitch_time'] = [ + ('transconv3d', (32, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 0 (4, 16, 84) + ('transconv3d', (16, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 1 (4, 96, 84) +] + +NET_G['bar_time_pitch'] = [ + ('transconv3d', (32, (1, 6, 1), (1, 6, 1)), 'bn', 'lrelu'), # 0 (4, 96, 7) + ('transconv3d', (16, (1, 1, 12), (1, 1, 12)), 'bn', 'lrelu'),# 1 (4, 96, 84) +] + +NET_G['bar_merged'] = [ + ('transconv3d', (1, (1, 1, 1), (1, 1, 1)), 'bn', 'sigmoid'), +] diff --git a/musegan/utils/__init__.py b/musegan/utils/__init__.py new file mode 100644 index 00000000..0f71644a --- /dev/null +++ b/musegan/utils/__init__.py @@ -0,0 +1,2 @@ +"""Utilities +""" diff --git a/musegan/utils/image_io.py b/musegan/utils/image_io.py new file mode 100644 index 00000000..879759aa --- /dev/null +++ b/musegan/utils/image_io.py @@ -0,0 +1,88 @@ +"""Utilities for creating image grids from a batch of images. +""" +import numpy as np +import imageio + +def get_image_grid(images, shape, grid_width=0, grid_color=0, + frame=False): + """ + Merge the input images and return a merged grid image. + + Arguments + --------- + images : np.array, ndim=3 + The image array. Shape is (num_image, height, width). + shape : list or tuple of int + Shape of the image grid. (height, width) + grid_width : int + Width of the grid lines. Default to 0. + grid_color : int + Color of the grid lines. Available values are 0 (black) to + 255 (white). Default to 0. + frame : bool + True to add frame. Default to False. + + Returns + ------- + merged : np.array, ndim=3 + The merged grid image. + """ + reshaped = images.reshape(shape[0], shape[1], images.shape[1], + images.shape[2]) + pad_width = ((0, 0), (0, 0), (grid_width, 0), (grid_width, 0)) + padded = np.pad(reshaped, pad_width, 'constant', constant_values=grid_color) + transposed = padded.transpose(0, 2, 1, 3) + merged = transposed.reshape(shape[0] * (images.shape[1] + grid_width), + shape[1] * (images.shape[2] + grid_width)) + if frame: + return np.pad(merged, ((0, grid_width), (0, grid_width)), 'constant', + constant_values=grid_color) + return merged[:-grid_width, :-grid_width] + +def save_image(filepath, phrases, shape, inverted=True, grid_width=3, + grid_color=0, frame=True): + """ + Save a batch of phrases to a single image grid. + + Arguments + --------- + filepath : str + Path to save the image grid. + phrases : np.array, ndim=5 + The phrase array. Shape is (num_phrase, num_bar, num_time_step, + num_pitch, num_track). + shape : list or tuple of int + Shape of the image grid. (height, width) + inverted : bool + True to invert the colors. Default to True. + grid_width : int + Width of the grid lines. Default to 3. + grid_color : int + Color of the grid lines. Available values are 0 (black) to + 255 (white). Default to 0. + frame : bool + True to add frame. Default to True. + """ + if phrases.dtype == np.bool_: + if inverted: + phrases = np.logical_not(phrases) + clipped = (phrases * 255).astype(np.uint8) + else: + if inverted: + phrases = 1. - phrases + clipped = (phrases * 255.).clip(0, 255).astype(np.uint8) + + flipped = np.flip(clipped, 3) + transposed = flipped.transpose(0, 4, 1, 3, 2) + reshaped = transposed.reshape(-1, phrases.shape[1] * phrases.shape[4], + phrases.shape[3], phrases.shape[2]) + + merged_phrases = [] + phrase_shape = (phrases.shape[4], phrases.shape[1]) + for phrase in reshaped: + merged_phrases.append(get_image_grid(phrase, phrase_shape, 1, + grid_color)) + + merged = get_image_grid(np.stack(merged_phrases), shape, grid_width, + grid_color, frame) + imageio.imwrite(filepath, merged) diff --git a/musegan/utils/metrics.py b/musegan/utils/metrics.py new file mode 100644 index 00000000..5375418b --- /dev/null +++ b/musegan/utils/metrics.py @@ -0,0 +1,276 @@ +"""Class and utilities for metrics +""" +import os +import warnings +import numpy as np +import matplotlib.pyplot as plt +import SharedArray as sa + +def get_tonal_matrix(r1=1.0, r2=1.0, r3=0.5): + """Compute and return a tonal matrix for computing the tonal distance [1]. + Default argument values are set as suggested by the paper. + + [1] Christopher Harte, Mark Sandler, and Martin Gasser. Detecting harmonic + change in musical audio. In Proc. ACM MM Workshop on Audio and Music + Computing Multimedia, 2006. + """ + tonal_matrix = np.empty((6, 12)) + tonal_matrix[0] = r1 * np.sin(np.arange(12) * (7. / 6.) * np.pi) + tonal_matrix[1] = r1 * np.cos(np.arange(12) * (7. / 6.) * np.pi) + tonal_matrix[2] = r2 * np.sin(np.arange(12) * (3. / 2.) * np.pi) + tonal_matrix[3] = r2 * np.cos(np.arange(12) * (3. / 2.) * np.pi) + tonal_matrix[4] = r3 * np.sin(np.arange(12) * (2. / 3.) * np.pi) + tonal_matrix[5] = r3 * np.cos(np.arange(12) * (2. / 3.) * np.pi) + return tonal_matrix + +def get_num_pitch_used(pianoroll): + """Return the number of unique pitches used in a piano-roll.""" + return np.sum(np.sum(pianoroll, 0) > 0) + +def get_qualified_note_rate(pianoroll, threshold=2): + """Return the ratio of the number of the qualified notes (notes longer than + `threshold` (in time step)) to the total number of notes in a piano-roll.""" + padded = np.pad(pianoroll.astype(int), ((1, 1), (0, 0)), 'constant') + diff = np.diff(padded, axis=0) + flattened = diff.T.reshape(-1,) + onsets = (flattened > 0).nonzero()[0] + offsets = (flattened < 0).nonzero()[0] + num_qualified_note = (offsets - onsets >= threshold).sum() + return num_qualified_note / len(onsets) + +def get_polyphonic_ratio(pianoroll, threshold=2): + """Return the ratio of the number of time steps where the number of pitches + being played is larger than `threshold` to the total number of time steps""" + return np.sum(np.sum(pianoroll, 1) >= threshold) / pianoroll.shape[0] + +def get_in_scale(chroma, scale_mask=None): + """Return the ratio of chroma.""" + measure_chroma = np.sum(chroma, axis=0) + in_scale = np.sum(np.multiply(measure_chroma, scale_mask, dtype=float)) + return in_scale / np.sum(chroma) + +def get_drum_pattern(measure, drum_filter): + """Return the drum_pattern metric value.""" + padded = np.pad(measure, ((1, 0), (0, 0)), 'constant') + measure = np.diff(padded, axis=0) + measure[measure < 0] = 0 + + max_score = 0 + for i in range(6): + cdf = np.roll(drum_filter, i) + score = np.sum(np.multiply(cdf, np.sum(measure, 1))) + if score > max_score: + max_score = score + + return max_score / np.sum(measure) + +def get_harmonicity(bar_chroma1, bar_chroma2, resolution, tonal_matrix=None): + """Return the harmonicity metric value""" + if tonal_matrix is None: + tonal_matrix = get_tonal_matrix() + warnings.warn("`tonal matrix` not specified. Use default tonal matrix", + RuntimeWarning) + score_list = [] + for r in range(bar_chroma1.shape[0]//resolution): + start = r * resolution + end = (r + 1) * resolution + beat_chroma1 = np.sum(bar_chroma1[start:end], 0) + beat_chroma2 = np.sum(bar_chroma2[start:end], 0) + score_list.append(tonal_dist(beat_chroma1, beat_chroma2, tonal_matrix)) + return np.mean(score_list) + +def to_chroma(pianoroll): + """Return the chroma features (not normalized).""" + padded = np.pad(pianoroll, ((0, 0), (0, 12 - pianoroll.shape[1] % 12)), + 'constant') + return np.sum(np.reshape(padded, (pianoroll.shape[0], 12, -1)), 2) + +def tonal_dist(chroma1, chroma2, tonal_matrix=None): + """Return the tonal distance between two chroma features.""" + if tonal_matrix is None: + tonal_matrix = get_tonal_matrix() + warnings.warn("`tonal matrix` not specified. Use default tonal matrix", + RuntimeWarning) + chroma1 = chroma1 / np.sum(chroma1) + result1 = np.matmul(tonal_matrix, chroma1) + chroma2 = chroma2 / np.sum(chroma2) + result2 = np.matmul(tonal_matrix, chroma2) + return np.linalg.norm(result1 - result2) + +def plot_histogram(hist, fig_dir=None, title=None, max_hist_num=None): + """Plot the histograms of the statistics""" + hist = hist[~np.isnan(hist)] + u_value = np.unique(hist) + + hist_num = len(u_value) + if max_hist_num is not None: + if len(u_value) > max_hist_num: + hist_num = max_hist_num + + fig = plt.figure() + plt.hist(hist, hist_num) + if title is not None: + plt.title(title) + if fig_dir is not None and title is not None: + fig.savefig(os.path.join(fig_dir, title)) + plt.close(fig) + +class Metrics(object): + """Class for metrics. + """ + def __init__(self, config): + self.metric_map = config['metric_map'] + self.tonal_distance_pairs = config['tonal_distance_pairs'] + self.track_names = config['track_names'] + self.beat_resolution = config['beat_resolution'] + self.drum_filter = config['drum_filter'] + self.scale_mask = config['scale_mask'] + self.tonal_matrix = get_tonal_matrix( + config['tonal_matrix_coefficient'][0], + config['tonal_matrix_coefficient'][1], + config['tonal_matrix_coefficient'][2] + ) + + self.metric_names = [ + 'empty_bar', + 'pitch_used', + 'qualified_note', + 'polyphonicity', + 'in_scale', + 'drum_pattern', + 'chroma_used', + ] + + def print_metrics_mat(self, metrics_mat): + """Print the intratrack metrics as a nice formatting table""" + print(' ' * 12, ' '.join(['{:^14}'.format(metric_name) + for metric_name in self.metric_names])) + + for t, track_name in enumerate(self.track_names): + value_str = [] + for m in range(len(self.metric_names)): + if np.isnan(metrics_mat[m, t]): + value_str.append('{:14}'.format('')) + else: + value_str.append('{:^14}'.format('{:6.4f}'.format( + metrics_mat[m, t]))) + + print('{:12}'.format(track_name), ' '.join(value_str)) + + def print_metrics_pair(self, pair_matrix): + """Print the intertrack metrics as a nice formatting table""" + for idx, pair in enumerate(self.tonal_distance_pairs): + print("{:12} {:12} {:12.5f}".format( + self.track_names[pair[0]], self.track_names[pair[1]], + pair_matrix[idx])) + + def eval(self, bars, verbose=False, mat_path=None, fig_dir=None): + """Evaluate the input bars with the metrics""" + score_matrix = np.empty((len(self.metric_names), len(self.track_names), + bars.shape[0])) + score_matrix.fill(np.nan) + score_pair_matrix = np.zeros((len(self.tonal_distance_pairs), + bars.shape[0])) + score_pair_matrix.fill(np.nan) + + for b in range(bars.shape[0]): + for t in range(len(self.track_names)): + is_empty_bar = ~np.any(bars[b, ..., t]) + if self.metric_map[0, t]: + score_matrix[0, t, b] = is_empty_bar + if is_empty_bar: + continue + if self.metric_map[1, t]: + score_matrix[1, t, b] = get_num_pitch_used(bars[b, ..., t]) + if self.metric_map[2, t]: + score_matrix[2, t, b] = get_qualified_note_rate( + bars[b, ..., t]) + if self.metric_map[3, t]: + score_matrix[3, t, b] = get_polyphonic_ratio( + bars[b, ..., t]) + if self.metric_map[4, t]: + score_matrix[4, t, b] = get_in_scale( + to_chroma(bars[b, ..., t]), self.scale_mask) + if self.metric_map[5, t]: + score_matrix[5, t, b] = get_drum_pattern(bars[b, ..., t], + self.drum_filter) + if self.metric_map[6, t]: + score_matrix[6, t, b] = get_num_pitch_used( + to_chroma(bars[b, ..., t])) + + for p, pair in enumerate(self.tonal_distance_pairs): + score_pair_matrix[p, b] = get_harmonicity( + to_chroma(bars[b, ..., pair[0]]), + to_chroma(bars[b, ..., pair[1]]), self.beat_resolution, + self.tonal_matrix) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + score_matrix_mean = np.nanmean(score_matrix, axis=2) + score_pair_matrix_mean = np.nanmean(score_pair_matrix, axis=1) + + if verbose: + print("{:=^120}".format(' Evaluation ')) + print('Data Size:', bars.shape) + print("{:-^120}".format('Intratrack Evaluation')) + self.print_metrics_mat(score_matrix_mean) + print("{:-^120}".format('Intertrack Evaluation')) + self.print_metrics_pair(score_pair_matrix_mean) + + if fig_dir is not None: + if not os.path.exists(fig_dir): + os.makedirs(fig_dir) + if verbose: + print('[*] Plotting...') + for m, metric_name in enumerate(self.metric_names): + for t, track_name in enumerate(self.track_names): + if self.metric_map[m, t]: + temp = '-'.join(track_name.replace('.', ' ').split()) + title = '_'.join([metric_name, temp]) + plot_histogram(score_matrix[m, t], fig_dir=fig_dir, + title=title, max_hist_num=20) + if verbose: + print("Successfully saved to", fig_dir) + + if mat_path is not None: + if not mat_path.endswith(".npy"): + mat_path = mat_path + '.npy' + info_dict = { + 'score_matrix_mean': score_matrix_mean, + 'score_pair_matrix_mean': score_pair_matrix_mean} + if verbose: + print('[*] Saving score matrices...') + np.save(mat_path, info_dict) + if verbose: + print("Successfully saved to", mat_path) + + return score_matrix_mean, score_pair_matrix_mean + +def eval_dataset(filepath, result_dir, location, config): + """Run evaluation on a dataset stored in either shared array (if `location` + is 'sa') or in hard disk (if `location` is 'hd') and save the results to the + given directory. + + """ + print('[*] Loading dataset...') + if location == 'sa': + data = sa.attach(filepath) + elif location == 'hd': + data = sa.attach(filepath) + else: + raise ValueError("Unrecognized value for `location`") + + print('[*] Running evaluation') + data = data.reshape(-1, config['num_timestep'], config['num_pitch'], + config['num_track']) + metrics = Metrics(config) + _ = metrics.eval(data, verbose=True, + mat_path=os.path.join(result_dir, 'score_matrices.npy'), + fig_dir=result_dir) + +def print_mat_file(mat_path, config): + """Print the score matrices stored in a file.""" + metrics = Metrics(config) + with np.load(mat_path) as loaded: + metrics.print_metrics_mat(loaded['score_matrix_mean']) + metrics.print_metrics_pair(loaded['score_pair_matrix_mean']) diff --git a/musegan/utils/midi_io.py b/musegan/utils/midi_io.py new file mode 100644 index 00000000..2d9ef9a5 --- /dev/null +++ b/musegan/utils/midi_io.py @@ -0,0 +1,84 @@ +"""Utilities for writing piano-rolls to MIDI files. +""" +import numpy as np +from pypianoroll import Multitrack, Track + +def write_midi(filepath, pianorolls, program_nums=None, is_drums=None, + track_names=None, velocity=100, tempo=120.0, beat_resolution=24): + """ + Write the given piano-roll(s) to a single MIDI file. + + Arguments + --------- + filepath : str + Path to save the MIDI file. + pianorolls : np.array, ndim=3 + The piano-roll array to be written to the MIDI file. Shape is + (num_timestep, num_pitch, num_track). + program_nums : int or list of int + MIDI program number(s) to be assigned to the MIDI track(s). Available + values are 0 to 127. Must have the same length as `pianorolls`. + is_drums : list of bool + Drum indicator(s) to be assigned to the MIDI track(s). True for + drums. False for other instruments. Must have the same length as + `pianorolls`. + track_names : list of str + Track name(s) to be assigned to the MIDI track(s). + """ + if not np.issubdtype(pianorolls.dtype, np.bool_): + raise TypeError("Support only binary-valued piano-rolls") + if isinstance(program_nums, int): + program_nums = [program_nums] + if isinstance(is_drums, int): + is_drums = [is_drums] + + if pianorolls.shape[2] != len(program_nums): + raise ValueError("`pianorolls` and `program_nums` must have the same" + "length") + if pianorolls.shape[2] != len(is_drums): + raise ValueError("`pianorolls` and `is_drums` must have the same" + "length") + if program_nums is None: + program_nums = [0] * len(pianorolls) + if is_drums is None: + is_drums = [False] * len(pianorolls) + + multitrack = Multitrack(beat_resolution=beat_resolution, tempo=tempo) + for idx in range(pianorolls.shape[2]): + if track_names is None: + track = Track(pianorolls[..., idx], program_nums[idx], + is_drums[idx]) + else: + track = Track(pianorolls[..., idx], program_nums[idx], + is_drums[idx], track_names[idx]) + multitrack.append_track(track) + multitrack.write(filepath) + +def save_midi(filepath, phrases, config): + """ + Save a batch of phrases to a single MIDI file. + + Arguments + --------- + filepath : str + Path to save the image grid. + phrases : list of np.array + Phrase arrays to be saved. All arrays must have the same shape. + pause : int + Length of pauses (in timestep) to be inserted between phrases. + Default to 0. + """ + if not np.issubdtype(phrases.dtype, np.bool_): + raise TypeError("Support only binary-valued piano-rolls") + + reshaped = phrases.reshape(-1, phrases.shape[1] * phrases.shape[2], + phrases.shape[3], phrases.shape[4]) + pad_width = ((0, 0), (0, config['pause_between_samples']), + (config['lowest_pitch'], + 128 - config['lowest_pitch'] - config['num_pitch']), + (0, 0)) + padded = np.pad(reshaped, pad_width, 'constant') + pianorolls = padded.reshape(-1, padded.shape[2], padded.shape[3]) + + write_midi(filepath, pianorolls, config['programs'], config['is_drums'], + tempo=config['tempo']) diff --git a/musegan/utils/neuralnet.py b/musegan/utils/neuralnet.py new file mode 100644 index 00000000..dc264ac1 --- /dev/null +++ b/musegan/utils/neuralnet.py @@ -0,0 +1,255 @@ +"""Classes for neural networks and layers. +""" +import numpy as np +import tensorflow as tf +from musegan.utils.ops import binary_stochastic_ST + +SUPPORTED_LAYER_TYPES = ( + 'reshape', 'mean', 'sum', 'dense', 'identity', 'conv1d', 'conv2d', 'conv3d', + 'transconv2d', 'transconv3d', 'avgpool2d', 'avgpool3d', 'maxpool2d', + 'maxpool3d' +) + +class Layer(object): + """Base class for layers.""" + def __init__(self, tensor_in, structure=None, condition=None, + slope_tensor=None, name=None, reuse=None): + if not isinstance(tensor_in, tf.Tensor): + raise TypeError("`tensor_in` must be of tf.Tensor type") + + self.tensor_in = tensor_in + + if structure is not None: + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + if structure[0] not in SUPPORTED_LAYER_TYPES: + raise ValueError("Unknown layer type at " + self.scope.name) + self.layer_type = structure[0] + self.tensor_out = self.build(structure, condition, slope_tensor) + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + else: + self.scope = None + self.layer_type = 'bypass' + self.tensor_out = tensor_in + self.vars = [] + + def __repr__(self): + return "Layer({}, type={}, input_shape={}, output_shape={})".format( + self.scope.name, self.layer_type, self.tensor_in.get_shape(), + self.tensor_out.get_shape()) + + def get_summary(self): + """Return the summary string.""" + return "{:36} {:12} {:30}".format( + self.scope.name, self.layer_type, str(self.tensor_out.get_shape())) + + def build(self, structure, condition, slope_tensor): + """Build the layer.""" + # Mean layers + if self.layer_type == 'mean': + keepdims = structure[2] if len(structure) > 2 else None + return tf.reduce_mean(self.tensor_in, structure[1], keepdims, + name='mean') + + # Summation layers + if self.layer_type == 'sum': + keepdims = structure[2] if len(structure) > 2 else None + return tf.reduce_sum(self.tensor_in, structure[1], keepdims, + name='sum') + + # Reshape layers + if self.layer_type == 'reshape': + if np.prod(structure[1]) != np.prod(self.tensor_in.get_shape()[1:]): + raise ValueError("Bad reshape size: {} to {} at {}".format( + self.tensor_in.get_shape()[1:], structure[1], + self.scope.name)) + if isinstance(structure[1], int): + reshape_shape = (-1, structure[1]) + else: + reshape_shape = (-1,) + structure[1] + return tf.reshape(self.tensor_in, reshape_shape, 'reshape') + + # Pooling layers + if self.layer_type == 'avgpool2d': + return tf.layers.average_pooling2d(self.tensor_in, structure[1][0], + structure[1][1], + name='avgpool2d') + if self.layer_type == 'maxpool2d': + return tf.layers.max_pooling2d(self.tensor_in, structure[1][0], + structure[1][1], name='maxpool2d') + if self.layer_type == 'avgpool3d': + return tf.layers.average_pooling3d(self.tensor_in, structure[1][0], + structure[1][1], + name='avgpool3d') + if self.layer_type == 'maxpool3d': + return tf.layers.max_pooling3d(self.tensor_in, structure[1][0], + structure[1][1], name='maxpool3d') + + # Condition + if condition is None: + self.conditioned = self.tensor_in + elif self.layer_type == 'dense': + self.conditioned = tf.concat([self.tensor_in, condition], 1) + elif self.layer_type in ('conv1d', 'conv2d', 'transconv2d', 'conv3d', + 'transconv3d'): + if self.layer_type == 'conv1d': + reshape_shape = (-1, 1, condition.get_shape()[1]) + elif self.layer_type in ('conv2d', 'transconv2d'): + reshape_shape = (-1, 1, 1, condition.get_shape()[1]) + else: # ('conv3d', 'transconv3d') + reshape_shape = (-1, 1, 1, 1, condition.get_shape()[1]) + reshaped = tf.reshape(condition, reshape_shape) + out_shape = ([-1] + self.tensor_in.get_shape()[1:-1] + + [condition.get_shape()[1]]) + to_concat = reshaped * tf.ones(out_shape) + self.conditioned = tf.concat([self.tensor_in, to_concat], -1) + + # Core layers (dense, convolutional or identity layer) + if self.layer_type == 'dense': + kernel_initializer = tf.truncated_normal_initializer(stddev=0.02) + self.core = tf.layers.dense(self.conditioned, structure[1], + kernel_initializer=kernel_initializer, + name='dense') + elif self.layer_type == 'identity': + self.core = self.conditioned + else: + filters = structure[1][0] + kernel_size = structure[1][1] + strides = structure[1][2] if len(structure[1]) > 2 else 1 + padding = structure[1][3] if len(structure[1]) > 3 else 'valid' + kernel_initializer = tf.truncated_normal_initializer(stddev=0.02) + + if self.layer_type == 'conv1d': + self.core = tf.layers.conv1d( + self.conditioned, filters, kernel_size, strides, padding, + kernel_initializer=kernel_initializer, name='conv1d') + elif self.layer_type == 'conv2d': + self.core = tf.layers.conv2d( + self.conditioned, filters, kernel_size, strides, padding, + kernel_initializer=kernel_initializer, name='conv2d') + elif self.layer_type == 'transconv2d': + self.core = tf.layers.conv2d_transpose( + self.conditioned, filters, kernel_size, strides, padding, + kernel_initializer=kernel_initializer, name='transconv2d') + elif self.layer_type == 'conv3d': + self.core = tf.layers.conv3d( + self.conditioned, filters, kernel_size, strides, padding, + kernel_initializer=kernel_initializer, name='conv3d') + elif self.layer_type == 'transconv3d': + self.core = tf.layers.conv3d_transpose( + self.conditioned, filters, kernel_size, strides, padding, + kernel_initializer=kernel_initializer, name='transconv3d') + + # normalization layer + if len(structure) > 2: + if structure[2] not in (None, 'bn', 'in', 'ln'): + raise ValueError("Unknown normalization at " + self.scope.name) + normalization = structure[2] + else: + normalization = None + + if normalization is None: + self.normalized = self.core + elif normalization == 'bn': + self.normalized = tf.layers.batch_normalization( + self.core, name='batch_norm') + elif normalization == 'in': + self.normalized = tf.contrib.layers.instance_norm( + self.core, scope='instance_norm') + elif normalization == 'ln': + self.normalized = tf.contrib.layers.layer_norm( + self.core, scope='layer_norm') + + # activation + if len(structure) > 3: + if structure[3] not in (None, 'tanh', 'sigmoid', 'relu', 'lrelu', + 'bernoulli', 'round'): + raise ValueError("Unknown activation at " + self.scope.name) + activation = structure[3] + else: + activation = None + + if activation is None: + self.activated = self.normalized + elif activation == 'tanh': + self.activated = tf.nn.tanh(self.normalized, 'tanh') + elif activation == 'sigmoid': + self.activated = tf.nn.sigmoid(self.normalized, 'sigmoid') + elif activation == 'relu': + self.activated = tf.nn.relu(self.normalized, 'relu') + elif activation == 'lrelu': + self.activated = tf.nn.leaky_relu(self.normalized, name='lrelu') + elif activation == 'bernoulli': + self.activated, self.preactivated = binary_stochastic_ST( + self.normalized, slope_tensor, False, True) + elif activation == 'round': + self.activated, self.preactivated = binary_stochastic_ST( + self.normalized, slope_tensor, False, False) + + return self.activated + +class NeuralNet(object): + """Base class for neural networks.""" + def __init__(self, tensor_in, architecture=None, condition=None, + slope_tensor=None, name='NeuralNet', reuse=None): + if not isinstance(tensor_in, tf.Tensor): + raise TypeError("`tensor_in` must be of tf.Tensor type") + + self.tensor_in = tensor_in + self.condition = condition + self.slope_tensor = slope_tensor + + if architecture is not None: + with tf.variable_scope(name, reuse=reuse) as scope: + self.scope = scope + self.layers = self.build(architecture) + self.tensor_out = self.layers[-1].tensor_out + self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + self.scope.name) + else: + self.scope = None + self.layers = [] + self.tensor_out = tensor_in + self.vars = [] + + def __repr__(self): + return "NeuralNet({}, input_shape={}, output_shape={})".format( + self.scope.name, self.tensor_in.get_shape(), + self.tensor_out.get_shape()) + + def get_summary(self): + """Return the summary string.""" + return '\n'.join( + ['[{}]'.format(self.scope.name), + "{:49} {}".format('Input', self.tensor_in.get_shape())] + + [x.get_summary() for x in self.layers]) + + def build(self, architecture): + """Build the neural network.""" + layers = [] + for idx, structure in enumerate(architecture): + if idx > 0: + prev_layer = layers[idx-1].tensor_out + else: + prev_layer = self.tensor_in + + # Skip connections + if len(structure) > 4: + skip_connection = structure[4][0] + else: + skip_connection = None + + if skip_connection is None: + connected = prev_layer + elif skip_connection == 'add': + connected = prev_layer + layers[structure[4][1]].tensor_out + elif skip_connection == 'concat': + connected = tf.concat( + [prev_layer, layers[structure[4][1]].tensor_out], -1) + + # Build layer + layers.append(Layer(connected, structure, + slope_tensor=self.slope_tensor, + name='Layer_{}'.format(idx))) + return layers diff --git a/musegan/utils/ops.py b/musegan/utils/ops.py new file mode 100644 index 00000000..c2262989 --- /dev/null +++ b/musegan/utils/ops.py @@ -0,0 +1,187 @@ +"""Operations for implementing binary neurons. Code is from the R2RT blog post: +https://r2rt.com/binary-stochastic-neurons-in-tensorflow.html (slightly adapted) +""" +import tensorflow as tf +from tensorflow.python.framework import ops + +def binary_round(x): + """ + Rounds a tensor whose values are in [0,1] to a tensor with values in + {0, 1}, using the straight through estimator for the gradient. + """ + g = tf.get_default_graph() + + with ops.name_scope("BinaryRound") as name: + with g.gradient_override_map({"Round": "Identity"}): + return tf.round(x, name=name) + +def bernoulli_sample(x): + """ + Uses a tensor whose values are in [0,1] to sample a tensor with values + in {0, 1}, using the straight through estimator for the gradient. + + E.g., if x is 0.6, bernoulliSample(x) will be 1 with probability 0.6, + and 0 otherwise, and the gradient will be pass-through (identity). + """ + g = tf.get_default_graph() + + with ops.name_scope("BernoulliSample") as name: + with g.gradient_override_map({"Ceil": "Identity", + "Sub": "BernoulliSample_ST"}): + return tf.ceil(x - tf.random_uniform(tf.shape(x)), name=name) + +@ops.RegisterGradient("BernoulliSample_ST") +def bernoulli_sample_ST(op, grad): + return [grad, tf.zeros(tf.shape(op.inputs[1]))] + +def pass_through_sigmoid(x, slope=1): + """Sigmoid that uses identity function as its gradient""" + g = tf.get_default_graph() + with ops.name_scope("PassThroughSigmoid") as name: + with g.gradient_override_map({"Sigmoid": "Identity"}): + return tf.sigmoid(x, name=name) + +def binary_stochastic_ST(x, slope_tensor=None, pass_through=True, + stochastic=True): + """ + Sigmoid followed by either a random sample from a bernoulli distribution + according to the result (binary stochastic neuron) (default), or a + sigmoid followed by a binary step function (if stochastic == False). + Uses the straight through estimator. See + https://arxiv.org/abs/1308.3432. + + Arguments: + * x: the pre-activation / logit tensor + * slope_tensor: if passThrough==False, slope adjusts the slope of the + sigmoid function for purposes of the Slope Annealing Trick (see + http://arxiv.org/abs/1609.01704) + * pass_through: if True (default), gradient of the entire function is 1 + or 0; if False, gradient of 1 is scaled by the gradient of the + sigmoid (required if Slope Annealing Trick is used) + * stochastic: binary stochastic neuron if True (default), or step + function if False + """ + if slope_tensor is None: + slope_tensor = tf.constant(1.0) + + if pass_through: + p = pass_through_sigmoid(x) + else: + p = tf.sigmoid(slope_tensor * x) + + if stochastic: + return bernoulli_sample(p), p + else: + return binary_round(p), p + +def binary_stochastic_REINFORCE(x, loss_op_name="loss_by_example"): + """ + Sigmoid followed by a random sample from a bernoulli distribution + according to the result (binary stochastic neuron). Uses the REINFORCE + estimator. See https://arxiv.org/abs/1308.3432. + + NOTE: Requires a loss operation with name matching the argument for + loss_op_name in the graph. This loss operation should be broken out by + example (i.e., not a single number for the entire batch). + """ + g = tf.get_default_graph() + + with ops.name_scope("BinaryStochasticREINFORCE"): + with g.gradient_override_map({"Sigmoid": "BinaryStochastic_REINFORCE", + "Ceil": "Identity"}): + p = tf.sigmoid(x) + + reinforce_collection = g.get_collection("REINFORCE") + if not reinforce_collection: + g.add_to_collection("REINFORCE", {}) + reinforce_collection = g.get_collection("REINFORCE") + reinforce_collection[0][p.op.name] = loss_op_name + + return tf.ceil(p - tf.random_uniform(tf.shape(x))) + + +@ops.RegisterGradient("BinaryStochastic_REINFORCE") +def _binaryStochastic_REINFORCE(op, _): + """Unbiased estimator for binary stochastic function based on REINFORCE.""" + loss_op_name = op.graph.get_collection("REINFORCE")[0][op.name] + loss_tensor = op.graph.get_operation_by_name(loss_op_name).outputs[0] + + sub_tensor = op.outputs[0].consumers()[0].outputs[0] #subtraction tensor + ceil_tensor = sub_tensor.consumers()[0].outputs[0] #ceiling tensor + + outcome_diff = (ceil_tensor - op.outputs[0]) + + # Provides an early out if we want to avoid variance adjustment for + # whatever reason (e.g., to show that variance adjustment helps) + if op.graph.get_collection("REINFORCE")[0].get("no_variance_adj"): + return outcome_diff * tf.expand_dims(loss_tensor, 1) + + outcome_diff_sq = tf.square(outcome_diff) + outcome_diff_sq_r = tf.reduce_mean(outcome_diff_sq, reduction_indices=0) + outcome_diff_sq_loss_r = tf.reduce_mean( + outcome_diff_sq * tf.expand_dims(loss_tensor, 1), reduction_indices=0) + + l_bar_num = tf.Variable(tf.zeros(outcome_diff_sq_r.get_shape()), + trainable=False) + l_bar_den = tf.Variable(tf.ones(outcome_diff_sq_r.get_shape()), + trainable=False) + + # Note: we already get a decent estimate of the average from the minibatch + decay = 0.95 + train_l_bar_num = tf.assign(l_bar_num, l_bar_num*decay +\ + outcome_diff_sq_loss_r*(1-decay)) + train_l_bar_den = tf.assign(l_bar_den, l_bar_den*decay +\ + outcome_diff_sq_r*(1-decay)) + + + with tf.control_dependencies([train_l_bar_num, train_l_bar_den]): + l_bar = train_l_bar_num/(train_l_bar_den + 1e-4) + l = tf.tile(tf.expand_dims(loss_tensor, 1), + tf.constant([1, l_bar.get_shape().as_list()[0]])) + return outcome_diff * (l - l_bar) + +def binary_wrapper(pre_activations_tensor, estimator, + stochastic_tensor=tf.constant(True), pass_through=True, + slope_tensor=tf.constant(1.0)): + """ + Turns a layer of pre-activations (logits) into a layer of binary + stochastic neurons + + Keyword arguments: + *estimator: either ST or REINFORCE + *stochastic_tensor: a boolean tensor indicating whether to sample from a + bernoulli distribution (True, default) or use a step_function (e.g., + for inference) + *pass_through: for ST only - boolean as to whether to substitute + identity derivative on the backprop (True, default), or whether to + use the derivative of the sigmoid + *slope_tensor: for ST only - tensor specifying the slope for purposes of + slope annealing trick + """ + if estimator == 'straight_through': + if pass_through: + return tf.cond( + stochastic_tensor, + lambda: binary_stochastic_ST(pre_activations_tensor), + lambda: binary_stochastic_ST(pre_activations_tensor, + stochastic=False)) + else: + return tf.cond( + stochastic_tensor, + lambda: binary_stochastic_ST(pre_activations_tensor, + slope_tensor, False), + lambda: binary_stochastic_ST(pre_activations_tensor, + slope_tensor, False, False)) + + elif estimator == 'reinforce': + # binaryStochastic_REINFORCE was designed to only be stochastic, so + # using the ST version for the step fn for purposes of using step + # fn at evaluation / not for training + return tf.cond( + stochastic_tensor, + lambda: binary_stochastic_REINFORCE(pre_activations_tensor), + lambda: binary_stochastic_ST(pre_activations_tensor, + stochastic=False)) + + else: + raise ValueError("Unrecognized estimator.") diff --git a/pretrained/README.md b/pretrained/README.md new file mode 100644 index 00000000..72ab7629 --- /dev/null +++ b/pretrained/README.md @@ -0,0 +1,30 @@ +# Shell scripts for downloading pretrained model + +## Download a particular pretrained model + +Run + +```sh +sh download.sh [model] [filename] +``` + +This will download a particular pretrained model to the current working +directory. Available files are listed as follows: + +- MuseGAN (`musegan`) + - `lastfm_alternative_g_hybrid_d_proposed.tar.gz` + +- BinaryMuseGAN (`bmusegan`) + - `lastfm_alternative_first_stage_d_proposed.tar.gz` + - `lastfm_alternative_first_stage_d_abalted.tar.gz` + - `lastfm_alternative_first_stage_d_baseline.tar.gz` + +## Download all pretrained models + +Run + +```sh +sh download_all.sh +``` + +This will download all pretrained models to the current working directory. diff --git a/pretrained/download.sh b/pretrained/download.sh new file mode 100644 index 00000000..834a4457 --- /dev/null +++ b/pretrained/download.sh @@ -0,0 +1,45 @@ +#!/bin/bash +case $2 in + *.tar.gz) + filename=$2 + ;; + *) + filename=$2.tar.gz + ;; +esac + +case $1 in + "musegan"|"MuseGAN") + case $filename in + "lastfm_alternative_g_hybrid_d_proposed.tar.gz") + fileid=1b1bwTzW09QPFbRn2Hy9X8yU1fbTc3S1k + ;; + *) + echo "File not found" + exit 1 + ;; + esac + "bmusegan"|"binarymusegan"|"BinaryMuseGAN") + case $filename in + "lastfm_alternative_first_stage_d_proposed.tar.gz") + fileid=12tEzs-Qa-qi59hLJB8TlD-vcZgVEQZu6 + ;; + "lastfm_alternative_first_stage_d_ablated.tar.gz") + fileid=1GolkoE2ktmHF2Pt7POd8TBBYZARu6ih8 + ;; + "lastfm_alternative_first_stage_d_baseline.tar.gz") + fileid=1qWWWU6UTMJvzdK6y4bvh3PRXF5Xbk09v + ;; + *) + echo "File not found" + exit 1 + ;; + esac + *) + echo "Unrecognizeable model name" + exit 1 + ;; +esac + +wget -O $filename --no-check-certificate \ + "https://drive.google.com/uc?export=download&id="$fileid diff --git a/pretrained/download_all.sh b/pretrained/download_all.sh new file mode 100644 index 00000000..a289a1ee --- /dev/null +++ b/pretrained/download_all.sh @@ -0,0 +1,7 @@ +#!/bin/bash +sh download.sh musegan lastfm_alternative_g_hybrid_d_proposed + +for postfix in proposed ablated baseline +do + sh download.sh bmusegan lastfm_alternative_first_stage_d_$postfix +done diff --git a/training_data/README.md b/training_data/README.md new file mode 100644 index 00000000..ebc59713 --- /dev/null +++ b/training_data/README.md @@ -0,0 +1,31 @@ +# Shell scripts for downloading training data + +## Download particular training data + +Run + +```sh +sh download.sh [filename] +``` + +This will download the training data to the current working directory. Available +training data are listed as follows: + +- `lastfm_alternative_5b_phrase.npy` contains 12,444 four-bar phrases from 2,074 + songs with *alternative* tags. The shape is (2074, 6, 4, 96, 84, 5). The five + tracks are *Drums*, *Piano*, *Guitar*, *Bass* and *Strings*. + +- `lastfm_alternative_8b_phrase.npy` contains 13,746 four-bar phrases from 2,291 + songs with *alternative* tags. The shape is (2291, 6, 4, 96, 84, 8). The + eight tracks are *Drums*, *Piano*, *Guitar*, *Bass*, *Ensemble*, *Reed*, + *Synth Lead* and *Synth Pad*. + +## Download all training data + +Run + +```sh +sh download_all.sh +``` + +This will download all training data to the current working directory. diff --git a/training_data/download.sh b/training_data/download.sh new file mode 100644 index 00000000..b12805c0 --- /dev/null +++ b/training_data/download.sh @@ -0,0 +1,25 @@ +#!/bin/bash +case $1 in + *.npy) + filename=$1 + ;; + *) + filename=$1.npy + ;; +esac + +case $filename in + "lastfm_alternative_5b_phrase.npy") + fileid=1F7J5n9uOPqViBYpoPT5GvE4PjCWhOyWc + ;; + "lastfm_alternative_8b_phrase.npy") + fileid=1x3CeSqE6ElWa6V7ueNl8FKPFmMoyu4ED + ;; + *) + echo "File not found" + exit 1 + ;; +esac + +wget -O $filename --no-check-certificate \ + "https://drive.google.com/uc?export=download&id="$fileid diff --git a/training_data/download_all.sh b/training_data/download_all.sh new file mode 100644 index 00000000..23b2753d --- /dev/null +++ b/training_data/download_all.sh @@ -0,0 +1,3 @@ +#!/bin/bash +sh download.sh lastfm_alternative_5b_phrase.npy +sh download.sh lastfm_alternative_8b_phrase.npy diff --git a/training_data/store_to_sa.py b/training_data/store_to_sa.py new file mode 100644 index 00000000..f38f790f --- /dev/null +++ b/training_data/store_to_sa.py @@ -0,0 +1,37 @@ +"""Store a numpy array to shared memory via SharedArray package. +""" +import os.path +import argparse +import numpy as np +import SharedArray as sa + +def parse_arguments(): + """Parse and return the command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument('filepath', help="Path to the data file.") + parser.add_argument('--name', help="File name to save in SharedArray. " + "Default to use the original file name.") + parser.add_argument('--prefix', help="Prefix to the file name to save in " + "SharedArray. Only effective when " + "`name` is not given.") + args = parser.parse_args() + return args.filepath, args.name, args.prefix + +def main(): + """Main function""" + filepath, name, prefix = parse_arguments() + + data = np.load(filepath) + + if name is None: + name = os.path.splitext(os.path.basename(filepath))[0] + if prefix is not None: + name = prefix + '_' + name + + sa_array = sa.create(name, data.shape, data.dtype) + np.copyto(sa_array, data) + + print("Successfully saved: {}, {}, {}".format(name, data.shape, data.dtype)) + +if __name__ == '__main__': + main()