
Commit ff5fba5

simplify code by generating the dataset every time instead of caching it

since we set the random seed, the dataset is deterministic and very fast to generate, so caching is unnecessary.
1 parent 390a714 commit ff5fba5
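The determinism claim is easy to sanity-check. A minimal sketch, assuming numpy's global RNG is seeded before generation (the actual seed value and call site are not shown in this commit): reseeding and regenerating yields bit-identical data on every run, which is what makes an on-disk cache redundant.

import numpy as np

def generate(seed=42):
    # hypothetical seed; the repo sets its own somewhere outside this diff
    np.random.seed(seed)
    # same shape as the XOR data: [num_sequences, num_bits]
    return np.random.randint(2, size=(100000, 50))

# two independent "runs" produce identical data, so nothing needs to be cached
assert np.array_equal(generate(), generate())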

File tree

3 files changed (+6, −53 lines)


train.py

Lines changed: 3 additions & 2 deletions

@@ -61,8 +61,9 @@ def train(params: ModelParams):
 
     optimizer = torch.optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum)
     loss_fn = torch.nn.BCEWithLogitsLoss()
-    train_loader = DataLoader(XORDataset(), batch_size=params.batch_size, shuffle=True)
-    test_loader = DataLoader(XORDataset(train=False), batch_size=params.batch_size)
+    train_loader = DataLoader(XORDataset(), batch_size=params.batch_size)
+    # test separately generated xor
+    test_loader = DataLoader(XORDataset(), batch_size=params.batch_size)
 
     step = 0
     epoch = 1
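Both loaders now draw a freshly generated XORDataset; the shuffle=True flag also disappears, presumably because freshly generated random sequences arrive in random order anyway. A usage sketch of the new loaders (the batch size is illustrative, and the shapes assume the default 50-bit sequences with per-timestep parity labels, matching the [batch, bits, 1] comments in xor_dataset.py):

from torch.utils.data import DataLoader

from xor_dataset import XORDataset

train_loader = DataLoader(XORDataset(), batch_size=32)

features, labels = next(iter(train_loader))
print(features.shape)  # torch.Size([32, 50, 1]) -- [batch, bits, 1] for the lstm
print(labels.shape)    # torch.Size([32, 50, 1]) -- [batch, parity, 1]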

utils.py

Lines changed: 0 additions & 16 deletions

@@ -16,19 +16,3 @@ def register_parser_types(parser, params_named_tuple):
 
     for key, _type in hints.items():
         parser.add_argument(f'--{key}', type=_type, default=defaults.get(key))
-
-
-# ------------------------- Path Utils -------------------------
-
-
-def ensure_path(path):
-    """Create the path if it does not exist"""
-    if not os.path.exists(path):
-        os.makedirs(path)
-    return path
-
-
-def remove_path(path):
-    """Remove the path if it exists."""
-    if os.path.exists(path):
-        shutil.rmtree(path)
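Nothing replaces the deleted helpers because they no longer have callers; if they are ever needed again, the standard library covers each in a single call (a minimal sketch, not part of this commit):

import os
import shutil

os.makedirs('./data', exist_ok=True)         # ensure_path, minus the return value
shutil.rmtree('./data', ignore_errors=True)  # remove_path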

xor_dataset.py

Lines changed: 3 additions & 35 deletions

@@ -3,8 +3,6 @@
 import torch
 import torch.utils.data as data
 
-from utils import ensure_path, remove_path
-
 DEFAULT_NUM_BITS = 50
 DEFAULT_NUM_SEQUENCES = 100000
 
@@ -14,41 +12,16 @@
 class XORDataset(data.Dataset):
     data_folder = './data'
 
-    def __init__(self, train=True, test_size=0.2):
-        self._test_size = test_size
-        self.train = train
-
-        # cache dataset so training is deterministic
-        self.ensure_sequences()
-
-        filename = 'train.pt' if self.train else 'test.pt'
-        self.features, self.labels = torch.load(f'{self.data_folder}/{filename}')
+    def __init__(self):
+        self.features, self.labels = get_random_bits_parity()
 
         # expand the dimensions for the lstm
         # [batch, bits] -> [batch, bits, 1]
         self.features = np.expand_dims(self.features, -1)
+
        # [batch, parity] -> [batch, parity, 1]
         self.labels = np.expand_dims(self.labels, -1)
 
-    def ensure_sequences(self):
-        if os.path.exists(self.data_folder):
-            return
-
-        ensure_path(self.data_folder)
-
-        features, labels = get_random_bits_parity()
-
-        test_start = int(len(features) * (1 - self._test_size))
-
-        train_set = (features[:test_start], labels[:test_start])
-        test_set = (features[test_start:], labels[test_start:])
-
-        with open(f'{self.data_folder}/train.pt', 'wb') as file:
-            torch.save(train_set, file)
-
-        with open(f'{self.data_folder}/test.pt', 'wb') as file:
-            torch.save(test_set, file)
-
     def __getitem__(self, index):
         return self.features[index, :], self.labels[index]
 
@@ -73,8 +46,3 @@ def get_random_bits_parity(num_sequences=DEFAULT_NUM_SEQUENCES, num_bits=DEFAULT_NUM_BITS):
     parity = bitsum % 2 != 0
 
     return bit_sequences.astype('float32'), parity.astype('float32')
-
-
-if __name__ == '__main__':
-    remove_path(XORDataset.data_folder)
-    XORDataset(test_size=0.2)
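For context, a hedged reconstruction of get_random_bits_parity from the fragments visible in this diff; the two generation lines are assumptions (a per-timestep cumulative sum is what would give the labels their [batch, parity] shape before the expand_dims call above), while the final lines appear verbatim in the hunk:

import numpy as np

DEFAULT_NUM_BITS = 50
DEFAULT_NUM_SEQUENCES = 100000

def get_random_bits_parity(num_sequences=DEFAULT_NUM_SEQUENCES, num_bits=DEFAULT_NUM_BITS):
    """Random bit sequences plus the parity of the running bit count."""
    bit_sequences = np.random.randint(2, size=(num_sequences, num_bits))  # assumed
    bitsum = np.cumsum(bit_sequences, axis=1)                             # assumed
    parity = bitsum % 2 != 0
    return bit_sequences.astype('float32'), parity.astype('float32')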
