3
3
import torch
4
4
import torch .utils .data as data
5
5
6
- from utils import ensure_path , remove_path
7
-
8
6
DEFAULT_NUM_BITS = 50
9
7
DEFAULT_NUM_SEQUENCES = 100000
10
8
14
12
class XORDataset (data .Dataset ):
15
13
data_folder = './data'
16
14
17
- def __init__ (self , train = True , test_size = 0.2 ):
18
- self ._test_size = test_size
19
- self .train = train
20
-
21
- # cache dataset so training is deterministic
22
- self .ensure_sequences ()
23
-
24
- filename = 'train.pt' if self .train else 'test.pt'
25
- self .features , self .labels = torch .load (f'{ self .data_folder } /{ filename } ' )
15
+ def __init__ (self ):
16
+ self .features , self .labels = get_random_bits_parity ()
26
17
27
18
# expand the dimensions for the lstm
28
19
# [batch, bits] -> [batch, bits, 1]
29
20
self .features = np .expand_dims (self .features , - 1 )
21
+
30
22
# [batch, parity] -> [batch, parity, 1]
31
23
self .labels = np .expand_dims (self .labels , - 1 )
32
24
33
- def ensure_sequences (self ):
34
- if os .path .exists (self .data_folder ):
35
- return
36
-
37
- ensure_path (self .data_folder )
38
-
39
- features , labels = get_random_bits_parity ()
40
-
41
- test_start = int (len (features ) * (1 - self ._test_size ))
42
-
43
- train_set = (features [:test_start ], labels [:test_start ])
44
- test_set = (features [test_start :], labels [test_start :])
45
-
46
- with open (f'{ self .data_folder } /train.pt' , 'wb' ) as file :
47
- torch .save (train_set , file )
48
-
49
- with open (f'{ self .data_folder } /test.pt' , 'wb' ) as file :
50
- torch .save (test_set , file )
51
-
52
25
def __getitem__ (self , index ):
53
26
return self .features [index , :], self .labels [index ]
54
27
@@ -73,8 +46,3 @@ def get_random_bits_parity(num_sequences=DEFAULT_NUM_SEQUENCES, num_bits=DEFAULT
73
46
parity = bitsum % 2 != 0
74
47
75
48
return bit_sequences .astype ('float32' ), parity .astype ('float32' )
76
-
77
-
78
- if __name__ == '__main__' :
79
- remove_path (XORDataset .data_folder )
80
- XORDataset (test_size = 0.2 )
0 commit comments