Skip to content

Commit cd8a9da

Browse files
committed
save dataset as float32 for pytorch
- remove the xor dataset when running the program as main - fix __getitem__ indexing since we changed dims to be [batch, bits]
1 parent 4765172 commit cd8a9da

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

xor_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def ensure_sequences(self):
4242
torch.save(test_set, file)
4343

4444
def __getitem__(self, index):
45-
return self.features[:, index], self.labels[index]
45+
return self.features[index, :], self.labels[index]
4646

4747
def __len__(self):
4848
return len(self.features)
@@ -58,10 +58,11 @@ def get_random_bits_parity(num_sequences=DEFAULT_NUM_SEQUENCES, num_bits=DEFAULT
5858

5959
# if total number of ones is odd, set even parity bit to 1, otherwise 0
6060
# https://en.wikipedia.org/wiki/Parity_bit
61-
parity = (bit_sequences.sum(axis=1) % 2 != 0).astype(int)
61+
parity = (bit_sequences.sum(axis=1) % 2 != 0)
6262

63-
return bit_sequences, parity
63+
return bit_sequences.astype('float32'), parity.astype('float32')
6464

6565

6666
if __name__ == '__main__':
67+
remove_path(XORDataset.data_folder)
6768
XORDataset(test_size=0.2)

0 commit comments

Comments
 (0)