1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ import torch .utils .data as data
5
+
6
+ from utils import ensure_path , remove_path
7
+
8
# Number of bits (time steps) in each generated sequence.
DEFAULT_SEQUENCE_LENGTH = 50
# Total number of sequences generated before the train/test split.
NUM_SEQUENCES = 100_000
10
+
11
class XORDataset(data.Dataset):
    """Dataset of random bit sequences labelled with their parity.

    On first use the sequences are produced by ``generate_random_sequences``,
    split into a training and a test portion, and cached as ``train.pt`` and
    ``test.pt`` inside ``data_folder``. Later instantiations load the cached
    portion selected by ``train``.

    Features are stored as ``[sequence_length, num_sequences, num_features]``,
    so individual samples live along axis 1; labels hold one row per sample.

    Args:
        train: load the training portion when true, the test portion otherwise.
        test_size: fraction of the generated sequences reserved for testing.
            Only used when the cache files have to be generated.

    Raises:
        ValueError: if the loaded cache holds a different number of feature
            sequences than labels (delete ``data_folder`` to regenerate it).
    """

    data_folder = './data'

    def __init__(self, train=True, test_size=0.2):
        self._test_size = test_size
        self.train = train

        self.ensure_sequences()

        filename = 'train.pt' if self.train else 'test.pt'
        # NOTE(review): the cached tuples hold NumPy arrays; recent torch
        # releases default torch.load to weights_only=True, which may refuse
        # them - confirm against the installed torch version.
        self.features, self.labels = torch.load(f'{self.data_folder}/{filename}')

        # A cache whose samples and labels disagree cannot be indexed
        # consistently, so fail early with an actionable message.
        if self.features.shape[1] != len(self.labels):
            raise ValueError(
                f'{self.data_folder}/{filename} holds {self.features.shape[1]} '
                f'sequences but {len(self.labels)} labels; delete '
                f'{self.data_folder} to regenerate the dataset'
            )

    @staticmethod
    def _split_train_test(features, labels, test_size):
        """Split features/labels into ``(train_set, test_set)`` by sample.

        The split runs along axis 1 of ``features`` (the sequence axis) and
        axis 0 of ``labels`` so every sequence keeps its full length and stays
        paired with its own label.
        """
        test_start = int(features.shape[1] * (1 - test_size))

        train_set = (features[:, :test_start], labels[:test_start])
        test_set = (features[:, test_start:], labels[test_start:])
        return train_set, test_set

    def ensure_sequences(self):
        """Generate and cache the train/test files unless both already exist."""
        train_path = f'{self.data_folder}/train.pt'
        test_path = f'{self.data_folder}/test.pt'

        # Check the files themselves: an existing but empty/unrelated folder
        # must not be mistaken for a populated cache.
        if os.path.exists(train_path) and os.path.exists(test_path):
            return

        if not os.path.exists(self.data_folder):
            ensure_path(self.data_folder)

        features, labels = generate_random_sequences()
        train_set, test_set = self._split_train_test(features, labels, self._test_size)

        with open(train_path, 'wb') as file:
            torch.save(train_set, file)

        with open(test_path, 'wb') as file:
            torch.save(test_set, file)

    def __getitem__(self, index):
        """Return ``(sequence, label)`` for the sample at ``index``."""
        return self.features[:, index], self.labels[index]

    def __len__(self):
        """Return the number of sequences (axis 1 of the feature array)."""
        return self.features.shape[1]
47
+
48
# Data layout: [sequence_length, num_sequences, num_features]
def generate_random_sequences(sequence_length=DEFAULT_SEQUENCE_LENGTH, num_sequences=NUM_SEQUENCES):
    """Draw random bit sequences together with their parity labels.

    Returns a ``(sequences, parity)`` pair of NumPy arrays: ``sequences`` has
    shape ``(sequence_length, num_sequences, 1)`` with entries in {0, 1}, and
    ``parity`` has shape ``(num_sequences, 1)`` holding 1 where a sequence
    contains an odd number of ones and 0 otherwise.
    """
    # the trailing axis of size 1 is the per-step feature dimension
    shape = (sequence_length, num_sequences, 1)
    sequences = np.random.randint(2, size=shape)

    # count the ones of each sequence by summing over the time axis
    ones_per_sequence = sequences.sum(axis=0)
    has_odd_ones = ones_per_sequence % 2 == 1

    return sequences, has_odd_ones.astype(int)
58
+
59
if __name__ == '__main__':
    # Running the module as a script instantiates the training dataset, which
    # generates the cached files when they are not present yet.
    XORDataset(train=True, test_size=0.2)
0 commit comments