Skip to content

Commit 44f81c6

Browse files
committed
initial commit
0 parents  commit 44f81c6

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__pycache__/
2+
data/
3+
venv/

generate_sequences.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
import os
3+
import torch
4+
import torch.utils.data as data
5+
6+
from utils import ensure_path, remove_path
7+
8+
DEFAULT_SEQUENCE_LENGTH = 50
9+
NUM_SEQUENCES = 100000
10+
11+
class XORDataset(data.Dataset):
12+
data_folder = './data'
13+
14+
def __init__(self, train=True, test_size=0.2):
15+
self._test_size = test_size
16+
self.train = train
17+
18+
self.ensure_sequences()
19+
20+
filename = 'train.pt' if self.train else 'test.pt'
21+
self.features, self.labels = torch.load(f'{self.data_folder}/{filename}')
22+
23+
def ensure_sequences(self):
24+
if os.path.exists(self.data_folder):
25+
return
26+
27+
ensure_path(self.data_folder)
28+
29+
features, labels = generate_random_sequences()
30+
31+
test_start = int(len(features) * (1 - self._test_size))
32+
33+
train_set = (features[:test_start], labels[:test_start])
34+
test_set = (features[test_start:], labels[test_start:])
35+
36+
with open(f'{self.data_folder}/train.pt', 'wb') as file:
37+
torch.save(train_set, file)
38+
39+
with open(f'{self.data_folder}/test.pt', 'wb') as file:
40+
torch.save(test_set, file)
41+
42+
def __getitem__(self, index):
43+
return self.features[:, index], self.labels[index]
44+
45+
def __len__(self):
46+
return len(self.features)
47+
48+
# Data dimensions: [sequence_length, num_sequences, num_features]
49+
def generate_random_sequences(sequence_length=DEFAULT_SEQUENCE_LENGTH, num_sequences=NUM_SEQUENCES):
50+
# generates num_sequences random bit sequences of length
51+
# extra dimension is num_features for pytorch, in this case 1
52+
sequences = np.random.randint(2, size=(sequence_length, num_sequences, 1))
53+
54+
# if total number of ones is odd, odd parity bit set to 1, otherwise 0
55+
parity = (sequences.sum(axis=0) % 2 != 0).astype(int)
56+
57+
return sequences, parity
58+
59+
if __name__ == '__main__':
60+
XORDataset(test_size=0.2)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
numpy==1.15.0
2+
torch==0.4.1

utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import os
2+
import shutil
3+
4+
def ensure_path(path):
5+
"""Create the path if it does not exist
6+
"""
7+
if not os.path.exists(path):
8+
os.makedirs(path)
9+
return path
10+
11+
def remove_path(path):
12+
"""Remove the path if it exists."""
13+
if os.path.exists(path):
14+
shutil.rmtree(path)

0 commit comments

Comments
 (0)