-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathauto_split.py
43 lines (33 loc) · 1.2 KB
/
auto_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
### author: Aashis Khanal
### sraashis@gmail.com
### date: 9/10/2018
"""
import json
def load_split_json(json_file):
    """Load a JSON split file, returning ``None`` if it cannot be read.

    Args:
        json_file: Path to a JSON file.

    Returns:
        The decoded JSON object, or ``None`` when the file cannot be opened
        or does not contain valid JSON. A status message is printed in
        either case.
    """
    try:
        with open(json_file, encoding='utf-8') as fh:
            split = json.load(fh)
    except (OSError, ValueError):
        # Best-effort loader: report and fall back to None instead of raising.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(json_file + ' FILE NOT LOADED !!!')
        return None
    print('### SPLIT FOUND: ', json_file + ' Loaded')
    return split
def create_splits(files, k=0, json_file='SPLIT', shuffle_files=True):
    """Write ``k`` rotating train/validation/test splits of *files* to JSON.

    The (optionally shuffled) items are cut into ``k`` nearly equal folds
    with ``numpy.array_split``. For fold ``i``:

    - fold ``i`` is the test set;
    - fold ``(i + 1) % k`` is the validation set;
    - every remaining item goes to training.

    Each split is written as ``{'train': [...], 'validation': [...],
    'test': [...]}`` to ``<json_file without extension>_<i>.json``.

    Args:
        files: Items to split (typically file names). They must be hashable
            and JSON-serializable. The caller's sequence is not modified;
            a copy is shuffled and split.
        k: Number of folds; must be >= 1. With ``k == 1`` validation and
            test are the same fold and training is empty. With ``k == 2``
            training is empty. Use ``k >= 3`` for three non-empty roles.
        json_file: Output path prefix. A trailing extension (e.g. ``.json``)
            is dropped before ``_<i>.json`` is appended.
        shuffle_files: Shuffle the copy before splitting.

    Raises:
        ValueError: If ``k < 1``.
    """
    from itertools import chain
    from random import shuffle
    import os

    import numpy as np

    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))

    # Work on a copy so shuffling never reorders the caller's list.
    files = list(files)
    if shuffle_files:
        shuffle(files)

    # Strip only the final extension. The former split('.')[0] cut the path
    # at its first dot, e.g. './out/SPLIT.json' -> ''.
    prefix = os.path.splitext(json_file)[0]

    folds = [fold.tolist() for fold in np.array_split(np.arange(len(files)), k)]
    for i, test_ix in enumerate(folds):
        val_ix = folds[(i + 1) % len(folds)]
        # Set membership keeps the train-index scan linear.
        held_out = set(test_ix) | set(val_ix)
        train_ix = [ix for ix in range(len(files)) if ix not in held_out]
        splits = {'train': [files[ix] for ix in train_ix],
                  'validation': [files[ix] for ix in val_ix],
                  'test': [files[ix] for ix in test_ix]}
        # Sanity check: every input item appears in at least one role.
        covered = set(chain.from_iterable(splits.values()))
        print('Valid:', set(files) <= covered)
        with open(prefix + '_' + str(i) + '.json', 'w', encoding='utf-8') as fh:
            json.dump(splits, fh)