preprocess.py
#!/usr/bin/env python3
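"""Preprocess MIDI-as-text files for language-model training.

Walks an input directory of text-encoded MIDI files (each file is expected
to hold a single line of space-separated tokens), chops each token stream
into blocks of at most 1024 tokens, splits the blocks 60/20/20 into
train/validation/test, and writes one JSON file per split.
"""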
import os
import argparse
import logging
import json

from sklearn.model_selection import train_test_split

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

SPLITS = ["train", "validation", "test"]
class MidiDataset:
    def __init__(self, path, max_files=None):
        self.data = None
        self.path = path
        # Maximum number of files that will be processed
        self.max_files = max_files
    def load(self):
        """Read each file, chop its token stream into blocks, and split the
        blocks into train/validation/test sets."""
        block_size = 1024
        all_blocks = []
        processed_files = 0
        walk_dir = os.path.abspath(self.path)
        for root, _, files in os.walk(walk_dir):
            for f in files:
                file_path = os.path.join(root, f)
                logger.info(f"Processing {file_path}")
                with open(file_path, "r") as midi_txt_file:
                    try:
                        line = midi_txt_file.readline()
                    except (OSError, UnicodeDecodeError):
                        logger.error(f"Error reading file {file_path} (ignoring)")
                        continue
                # Chop the token stream into blocks of at most block_size tokens
                tokens = line.split(" ")
                while len(tokens) > 0:
                    all_blocks.append(" ".join(tokens[:block_size]))
                    tokens = tokens[block_size:]
                processed_files += 1
                if self.max_files is not None and processed_files >= self.max_files:
                    break
            else:
                continue
            break  # propagate the inner break: stop walking once max_files is reached
        # 60/20/20 split: hold out 40%, then divide it evenly into validation and test
        train, val_test = train_test_split(all_blocks, test_size=0.4)
        validation, test = train_test_split(val_test, test_size=0.5)
        self.data = {"train": train, "validation": validation, "test": test}
class Preprocessor:
    def __init__(self, dataset, out_dirname):
        self.dataset = dataset
        self.out_dirname = out_dirname

    def process(self, split):
        # Wrap every block as an example of the form {"in": <token block>}
        output = {"data": [{"in": entry} for entry in self.dataset.data[split]]}
        with open(os.path.join(self.out_dirname, f"{split}.json"), "w") as f:
            json.dump(output, f, indent=4, ensure_ascii=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Base directory of the MIDI files converted to text")
    parser.add_argument("--output", type=str, required=True, help="Name of the output directory")
    parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process")
    args = parser.parse_args()
    logger.info(args)

    dataset = MidiDataset(path=args.input, max_files=args.max_files)
    try:
        dataset.load()
    except FileNotFoundError:
        logger.error("Dataset could not be loaded")
        raise

    out_dirname = args.output
    try:
        os.makedirs(out_dirname, exist_ok=True)
    except OSError:
        logger.error(f"Output directory {out_dirname} cannot be created")
        raise

    preprocessor = Preprocessor(dataset=dataset, out_dirname=out_dirname)
    for split in SPLITS:
        preprocessor.process(split)
    logger.info("Preprocessing finished.")