Skip to content

Commit 33e9b73

Browse files
author
Ozdenizci Ozan
committed
update ocm dataset loaders
1 parent c7f2eac commit 33e9b73

File tree

3 files changed

+18
-45
lines changed

3 files changed

+18
-45
lines changed

datasets/cifar10.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,12 @@ def __init__(self, args, normalize=True):
2323
self.tr_test = transforms.Compose(self.tr_test)
2424

2525
if args.ocm:
26-
if args.randcode:
27-
self.C = np.ones((args.num_classes, args.code_length), np.float32)
28-
for c in range(args.num_classes):
29-
how_much_negs = int(args.code_length // 4) + np.random.choice(int(args.code_length // 2))
30-
picks = np.random.choice(args.code_length, how_much_negs, replace=False)
31-
self.C[c, picks] = -1
32-
print(self.C)
33-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
34-
else:
35-
self.C = hadamard(args.code_length).astype(np.float32)
36-
self.C = np.delete(self.C, 0, axis=0)
37-
np.random.shuffle(self.C)
38-
self.C = self.C[:args.num_classes]
39-
print(self.C)
40-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
26+
self.C = hadamard(args.code_length).astype(np.float32)
27+
self.C = np.delete(self.C, 0, axis=0)
28+
np.random.shuffle(self.C)
29+
self.C = self.C[:args.num_classes]
30+
print(self.C)
31+
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
4132
else:
4233
self.tr_target = [transforms.Lambda(lambda y: y)]
4334

datasets/cifar100.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,12 @@ def __init__(self, args, normalize=True):
2323
self.tr_test = transforms.Compose(self.tr_test)
2424

2525
if args.ocm:
26-
if args.randcode:
27-
self.C = np.ones((args.num_classes, args.code_length), np.float32)
28-
for c in range(args.num_classes):
29-
how_much_negs = int(args.code_length // 4) + np.random.choice(int(args.code_length // 2))
30-
picks = np.random.choice(args.code_length, how_much_negs, replace=False)
31-
self.C[c, picks] = -1
32-
print(self.C)
33-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
34-
else:
35-
self.C = hadamard(args.code_length).astype(np.float32)
36-
self.C = np.delete(self.C, 0, axis=0)
37-
np.random.shuffle(self.C)
38-
self.C = self.C[:args.num_classes]
39-
print(self.C)
40-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
26+
self.C = hadamard(args.code_length).astype(np.float32)
27+
self.C = np.delete(self.C, 0, axis=0)
28+
np.random.shuffle(self.C)
29+
self.C = self.C[:args.num_classes]
30+
print(self.C)
31+
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
4132
else:
4233
self.tr_target = [transforms.Lambda(lambda y: y)]
4334

datasets/imagenet.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,12 @@ def __init__(self, args, normalize=True):
2525
self.tr_test = transforms.Compose(self.tr_test)
2626

2727
if args.ocm:
28-
if args.randcode:
29-
self.C = np.ones((args.num_classes, args.code_length), np.float32)
30-
for c in range(args.num_classes):
31-
how_much_negs = int(args.code_length // 4) + np.random.choice(int(args.code_length // 2))
32-
picks = np.random.choice(args.code_length, how_much_negs, replace=False)
33-
self.C[c, picks] = -1
34-
print(self.C)
35-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
36-
else:
37-
self.C = hadamard(args.code_length).astype(np.float32)
38-
self.C = np.delete(self.C, 0, axis=0)
39-
np.random.shuffle(self.C)
40-
self.C = self.C[:args.num_classes]
41-
print(self.C)
42-
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
28+
self.C = hadamard(args.code_length).astype(np.float32)
29+
self.C = np.delete(self.C, 0, axis=0)
30+
np.random.shuffle(self.C)
31+
self.C = self.C[:args.num_classes]
32+
print(self.C)
33+
self.tr_target = [transforms.Lambda(lambda y: torch.LongTensor(self.C[y]))]
4334
else:
4435
self.tr_target = [transforms.Lambda(lambda y: y)]
4536

0 commit comments

Comments
 (0)