Merging tensorboard branch
Tensorboard
gwinndr authored Mar 24, 2020
2 parents 6fc9a82 + a307429 commit a3b3c90
Showing 5 changed files with 315 additions and 67 deletions.
38 changes: 36 additions & 2 deletions README.md
@@ -19,15 +19,40 @@ In order to play .mid files, we used [Midi Editor](https://www.midieditor.org/)
* Fixed length song generation
* Midi augmentations from paper
* Multi-GPU support
* Experiment with tensorboard for result reporting

## How to run
You will first need to download the Maestro dataset (we used v2, but v1 should work as well). You can download the dataset [here](https://magenta.tensorflow.org/datasets/maestro) (you only need the midi version if you're tight on space). We use the midi pre-processor provided by jason9693 et al. (https://github.com/jason9693/midi-neural-processor) to convert the midi into discrete ordered message types for training and evaluating.

First run third_party/get_code.sh to download the midi pre-processor from github. If on Windows, look at the code and you'll see what to do (it's very simple :D). After, run preprocess_midi.py with --help for details. The result will be a pre-processed folder with a train, val, and test split as provided by Maestro's recommendation.
First run get_code.sh in third_party to download the midi pre-processor from github. If on Windows, look at the code and you'll see what to do (it's very simple :D). After, run preprocess_midi.py with --help for details. The result will be a pre-processed folder with a train, val, and test split as provided by Maestro's recommendation.

To train a model, run train.py. Use --help to see the tweakable parameters. See the results section for details on model performance. After training models, you can evaluate them with evaluate.py and generate a midi piece with generate.py. To graph and compare results visually, use graph_results.py.

You can leave most arguments at their default values. If your setup differs from the defaults (for example, a different dataset location), specify that in the arguments. Beyond that, the average user does not have to worry about most of the arguments.

### Training
For example, to train a model using the parameters specified in the results section:

```
python train.py -output_dir rpr --rpr
```
You can additionally specify a weight modulus and a print modulus, which determine on which epochs weights are saved and on which batches progress is printed.
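
As a rough illustration of how a modulus of this kind usually behaves, the sketch below saves on every epoch divisible by the weight modulus and prints on every batch divisible by the print modulus. The function and parameter names are hypothetical and may not match the arguments in train.py, and 1-indexed epochs and batches are an assumption.

```python
def should_save_weights(epoch, weight_modulus):
    """Hypothetical helper: True on epochs where weights would be saved."""
    return epoch % weight_modulus == 0


def should_print(batch, print_modulus):
    """Hypothetical helper: True on batches where progress would be printed."""
    return batch % print_modulus == 0


# With a weight modulus of 5, weights would be saved on epochs 5 and 10
# out of the first ten (assuming epochs are counted from 1).
saved_epochs = [e for e in range(1, 11) if should_save_weights(e, 5)]
print(saved_epochs)
```

A modulus of 1 saves or prints every time; larger values trade detail for less disk usage and less console output.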

### Evaluation
You can evaluate a model using:
```
python evaluate.py -model_weights rpr/results/best_acc_weights.pickle --rpr
```

Your model's results may vary because a random sequence start position is chosen for each evaluation piece. This may be changed in the future.
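
To see why results can differ between runs, here is a minimal sketch of picking a random window from a tokenized piece. This is an assumption about the general technique, not the repository's actual data loader; `random_window` and `max_seq` are hypothetical names. Passing a seeded generator shows one way such an evaluation could be made repeatable.

```python
import random


def random_window(tokens, max_seq, rng=random):
    """Return a contiguous window of at most max_seq tokens, starting at a random position."""
    if len(tokens) <= max_seq:
        return list(tokens)
    # randint is inclusive on both ends, so the last valid start is len - max_seq
    start = rng.randint(0, len(tokens) - max_seq)
    return list(tokens[start:start + max_seq])


piece = list(range(100))  # stand-in for a tokenized midi piece

# Two generators with the same seed pick the same window,
# while an unseeded run may pick a different one each time.
window_a = random_window(piece, 16, random.Random(0))
window_b = random_window(piece, 16, random.Random(0))
print(window_a == window_b)
```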

### Generation
You can generate a piece with a trained model by using:
```
python generate.py -output_dir output -model_weights rpr/results/best_acc_weights.pickle --rpr
```

The default generation method is a probability distribution over the softmaxed output. You can also use beam search but this simply does not work well and is not recommended.
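
Sampling from the softmaxed output can be sketched in plain Python as follows. This only illustrates the idea using the standard library; generate.py operates on model tensors, and the helper names here are hypothetical.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, rng=random):
    """Draw one token index at random, weighted by its softmax probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # toy scores for a three-token vocabulary
probs = softmax(logits)
token = sample_token(logits, random.Random(0))
print(probs, token)
```

Unlike always taking the highest-scoring token, sampling lets less likely tokens appear occasionally, which is what gives generated pieces their variety.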

## Pytorch Transformer
We used the Transformer class provided since Pytorch 1.2.0 (https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer). The provided Transformer assumes an encoder-decoder architecture. To make it decoder-only like the Music Transformer, we use stacked encoders with a custom dummy decoder. This decoder-only model can be found in model/music_transformer.py.

@@ -45,10 +70,19 @@ We trained a base and RPR model with the following parameters (taken from the pa
* **dim_feedforward**: 1024
* **dropout**: 0.1

The following graphs were generated with the command:
```
python graph_results.py -input_dirs base_model/results?rpr_model/results -model_names base?rpr
```

Note that multiple input models are separated with a '?'.
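
The '?'-separated arguments are split into parallel lists of directories and names. A minimal sketch of that pairing follows; `pair_dirs_with_names` is a hypothetical helper, and raising an error on a length mismatch is a choice made for this example (graph_results.py prints a message and returns instead).

```python
SPLITTER = "?"


def pair_dirs_with_names(input_dirs, model_names=None):
    """Split both '?'-separated strings and pair each directory with its display name."""
    dirs = input_dirs.split(SPLITTER)
    if model_names is None:
        return [(d, None) for d in dirs]

    names = model_names.split(SPLITTER)
    if len(names) != len(dirs):
        raise ValueError("len(model_names) != len(input_dirs)")
    return list(zip(dirs, names))


pairs = pair_dirs_with_names("base_model/results?rpr_model/results", "base?rpr")
print(pairs)
```

Because `?` is also a wildcard character in many shells, quoting the two arguments on the command line can avoid surprises.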

![Loss Results Graph](https://lh3.googleusercontent.com/u6AL9vIXG7gBeKuLlVJGFeex7-q2NYLbMqYVZGFI3qxWlpa6hAXdVlOsD52i4jKjrVcf4YZCGBaMIVIagcu_z-7Sg5YhDcgsqcs-p4aR48C287c1QraG0tRnHnmimLd8jizk9afW8g=w2400 "Loss Results")

![Accuracy Results Graph](https://lh3.googleusercontent.com/ajbanROlOAM9YrNDaHrv1tWM8tZ4nrcrTehwoHsaftnPPZ4xEBLG0RmBa4awYXntBQF0RR_Uh3bsLZv4mdzmZM_TNisMnreKsB2jZIY7iSZjQiL4kRumypymuxIiHu-VdPB0kUkILQ=w2400 "Accuracy Results")

![Learn Rate Results Graph](https://lh3.googleusercontent.com/Gz8N8tgHN2qstvdq77GqQQiukWjwBUettMK8IYV0228il5NvRdrnoISS5HTrxd7xVOrRpSzTtLlRppT-UwWJ2ke1XnAsRMbJ0bCElSvCQAA_z08HSZjbJ4wQXBbg4lVzuGdikEN5Ug=w2400 "Learn Rate Results")

Best loss for *base* model: 1.99 on epoch 250
Best loss for *rpr* model: 1.92 on epoch 216
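
A best-loss figure like the ones above can be read back from a results.csv file. The sketch below uses the epoch and evaluation-loss column positions that graph_results.py uses (0 and 4); the header names and the three data rows are made up purely for illustration and are not real training results.

```python
import csv
import io

EPOCH_IDX = 0
EVAL_LOSS_IDX = 4


def best_eval_loss(stream):
    """Return (epoch, loss) for the row with the lowest evaluation loss."""
    reader = csv.reader(stream)
    next(reader)  # skip the header row

    best_epoch, best_loss = None, float("inf")
    for row in reader:
        loss = float(row[EVAL_LOSS_IDX])
        if loss < best_loss:
            best_epoch, best_loss = int(row[EPOCH_IDX]), loss
    return best_epoch, best_loss


# Toy stand-in for a results.csv file (hypothetical header names and values).
toy_csv = io.StringIO(
    "epoch,lr,train_loss,train_acc,eval_loss,eval_acc\n"
    "1,0.001,3.0,0.10,2.9,0.12\n"
    "2,0.001,2.6,0.20,2.5,0.22\n"
    "3,0.001,2.4,0.25,2.7,0.21\n"
)
best = best_eval_loss(toy_csv)
print(best)
```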

157 changes: 151 additions & 6 deletions graph_results.py
@@ -1,22 +1,163 @@
import argparse
import os
import csv
import math
import matplotlib.pyplot as plt

RESULTS_FILE = "results.csv"
EPOCH_IDX = 0
LR_IDX = 1
EVAL_LOSS_IDX = 4
EVAL_ACC_IDX = 5

SPLITTER = '?'

# graph_results
def graph_results(input_dirs="./saved_models", output_dir=None, model_names=None, epoch_start=0, epoch_end=None):
def graph_results(input_dirs="./saved_models/results", output_dir=None, model_names=None, epoch_start=0, epoch_end=None):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Graphs model training and evaluation data
    ----------
    """

    input_dirs = input_dirs.split(SPLITTER)

    if(model_names is not None):
        model_names = model_names.split(SPLITTER)
        if(len(model_names) != len(input_dirs)):
            print("Error: len(model_names) != len(input_dirs)")
            return

    #Initialize Loss and Accuracy arrays
    loss_arrs = []
    accuracy_arrs = []
    epoch_counts = []
    lrs = []

    for input_dir in input_dirs:
        loss_arr = []
        accuracy_arr = []
        epoch_count = []
        lr_arr = []

        f = os.path.join(input_dir, RESULTS_FILE)
        with open(f, "r") as i_stream:
            reader = csv.reader(i_stream)
            next(reader)

            lines = [line for line in reader]

        if(epoch_end is None):
            epoch_end = math.inf

        epoch_start = max(epoch_start, 0)
        epoch_start = min(epoch_start, epoch_end)

        for line in lines:
            epoch = line[EPOCH_IDX]
            lr = line[LR_IDX]
            accuracy = line[EVAL_ACC_IDX]
            loss = line[EVAL_LOSS_IDX]

            if(int(epoch) >= epoch_start and int(epoch) < epoch_end):
                accuracy_arr.append(float(accuracy))
                loss_arr.append(float(loss))
                epoch_count.append(int(epoch))
                lr_arr.append(float(lr))

        loss_arrs.append(loss_arr)
        accuracy_arrs.append(accuracy_arr)
        epoch_counts.append(epoch_count)
        lrs.append(lr_arr)

    if(output_dir is not None):
        try:
            os.mkdir(output_dir)
        except OSError:
            print ("Creation of the directory %s failed" % output_dir)
        else:
            print ("Successfully created the directory %s" % output_dir)

    ##### LOSS #####
    for i in range(len(loss_arrs)):
        if(model_names is None):
            name = None
        else:
            name = model_names[i]

        #Create and save plots to output folder
        plt.plot(epoch_counts[i], loss_arrs[i], label=name)
        plt.title("Loss Results")
        plt.ylabel('Loss (Cross Entropy)')
        plt.xlabel('Epochs')
        fig1 = plt.gcf()

    plt.legend(loc="upper left")

    if(output_dir is not None):
        fig1.savefig(os.path.join(output_dir, 'loss_graph.png'))

    plt.show()

    ##### ACCURACY #####
    for i in range(len(accuracy_arrs)):
        if(model_names is None):
            name = None
        else:
            name = model_names[i]

        #Create and save plots to output folder
        plt.plot(epoch_counts[i], accuracy_arrs[i], label=name)
        plt.title("Accuracy Results")
        plt.ylabel('Accuracy')
        plt.xlabel('Epochs')
        fig2 = plt.gcf()

    plt.legend(loc="upper left")

    if(output_dir is not None):
        fig2.savefig(os.path.join(output_dir, 'accuracy_graph.png'))

    plt.show()

    ##### LR #####
    for i in range(len(lrs)):
        if(model_names is None):
            name = None
        else:
            name = model_names[i]

        #Create and save plots to output folder
        plt.plot(epoch_counts[i], lrs[i], label=name)
        plt.title("Learn Rate Results")
        plt.ylabel('Learn Rate')
        plt.xlabel('Epochs')
        fig2 = plt.gcf()

    plt.legend(loc="upper left")

    if(output_dir is not None):
        fig2.savefig(os.path.join(output_dir, 'lr_graph.png'))

    plt.show()

# graph_results_legacy
def graph_results_legacy(input_dirs="./saved_models/results", output_dir=None, model_names=None, epoch_start=0, epoch_end=None):
    """
    ----------
    Author: Ben Myrick
    Modified: Damon Gwinn
    ----------
    Graphs model training and evaluation data
    Graphs model training and evaluation data using the old results format (legacy)
    ----------
    """

    input_dirs = input_dirs.split(':')
    input_dirs = input_dirs.split(SPLITTER)

    if(model_names is not None):
        model_names = model_names.split(':')
        model_names = model_names.split(SPLITTER)
        if(len(model_names) != len(input_dirs)):
            print("Error: len(model_names) != len(input_dirs)")
            return
@@ -121,11 +262,12 @@ def parse_args():

    parser = argparse.ArgumentParser()

    parser.add_argument("-input_dirs", type=str, default="./saved_models/results", help="Input results folder from trained model ('results' folder). Seperate with ':' for comparisons between models")
    parser.add_argument("-input_dirs", type=str, default="./saved_models/results", help="Input results folder from trained model ('results' folder). Separate with '?' symbol for comparisons between models")
    parser.add_argument("-output_dir", type=str, default=None, help="Optional output folder to save graph pngs")
    parser.add_argument("-model_names", type=str, default=None, help="Names to display when color coding, separate with '?'.")
    parser.add_argument("-epoch_start", type=int, default=0, help="Epoch start. Defaults to first file.")
    parser.add_argument("-epoch_end", type=int, default=None, help="Epoch end (non-inclusive). Defaults to None.")
    parser.add_argument("--legacy", action="store_true", help="Use legacy results output format (you likely don't need this)")

    return parser.parse_args()

@@ -141,7 +283,10 @@ def main():

    args = parse_args()

    graph_results(args.input_dirs, args.output_dir, args.model_names, args.epoch_start, args.epoch_end)
    if(not args.legacy):
        graph_results(args.input_dirs, args.output_dir, args.model_names, args.epoch_start, args.epoch_end)
    else:
        graph_results_legacy(args.input_dirs, args.output_dir, args.model_names, args.epoch_start, args.epoch_end)

if __name__ == "__main__":
    main()