Skip to content

Commit

Permalink
add dataset path to config file, add missing dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
psandovalsegura committed Jun 23, 2022
1 parent e6b0c7e commit f861cb8
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 19 deletions.
13 changes: 7 additions & 6 deletions config/base.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
misc:
project_name: tmp-project-name
run_name: tmp-run-name
log_model: False # whether to have wandb save artifacts
log_model: False # whether to have wandb save artifacts
use_auto_scale_batch_size: False
enable_progress_bar: False
enable_checkpointing: False # whether to save model checkpoints
enable_checkpointing: False # whether to save model checkpoints
log_every_n_steps: 50
wandb_save_dir: '/vulcanscratch/psando/wandb'
dirpath: '/vulcanscratch/psando/poison_ckpts' # path to save checkpoints
wandb_save_dir: '/vulcanscratch/psando/wandb' # path to save wandb files
dirpath: '/vulcanscratch/psando/poison_ckpts' # path to save your checkpoints

train:
model_name: ResNet18
dataset: CIFAR10
dataset: CIFAR10 # either 'CIFAR10' 'CIFAR100' 'STL10' or 'SVHN'
dataset_path: '/vulcanscratch/psando/cifar-10/' # path to your dataset root
batch_size: 128
epochs: 100
num_workers: 16
Expand All @@ -21,4 +22,4 @@ train:
adversarial_poison_path: False
unlearnable_poison_path: False
dataset_path: False
augmentations_key: 'none' # either 'none' 'cutout' 'cutmix' or 'mixup'
augmentations_key: 'none' # either 'none' 'cutout' 'cutmix' or 'mixup'
8 changes: 5 additions & 3 deletions lightning_modules/lightning_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
momentum=0.9,
adversarial_poison_path=False,
unlearnable_poison_path=False,
base_dataset_path=None,
augmentations_key=None):
super().__init__()
self.model = get_model_class_from_name(model_name=model_name)
Expand All @@ -27,6 +28,7 @@ def __init__(self,
self.momentum = momentum
self.adversarial_poison_path = adversarial_poison_path
self.unlearnable_poison_path = unlearnable_poison_path
self.base_dataset_path = base_dataset_path
self.augmentations_key = augmentations_key
self.loss_fn = self.configure_criterion()
self.save_hyperparameters()
Expand Down Expand Up @@ -92,7 +94,7 @@ def train_dataloader(self):
])
transform_train = self.configure_transform(transform_train)
trainset = datasets.CIFAR10(
root='/vulcanscratch/psando/cifar-10/', train=True, download=False, transform=transform_train)
root=self.base_dataset_path, train=True, download=False, transform=transform_train)
if self.adversarial_poison_path:
trainset = AdversarialPoison(root=self.adversarial_poison_path, baseset=trainset)
if self.unlearnable_poison_path:
Expand All @@ -107,7 +109,7 @@ def val_dataloader(self):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = datasets.CIFAR10(
root='/vulcanscratch/psando/cifar-10/', train=False, download=False, transform=transform_test)
root=self.base_dataset_path, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand All @@ -117,7 +119,7 @@ def test_dataloader(self):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = datasets.CIFAR10(
root='/vulcanscratch/psando/cifar-10/', train=False, download=False, transform=transform_test)
root=self.base_dataset_path, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand Down
8 changes: 5 additions & 3 deletions lightning_modules/lightning_cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
momentum=0.9,
adversarial_poison_path=False,
unlearnable_poison_path=False,
base_dataset_path=None,
augmentations_key=None):
super().__init__()
self.model = get_model_class_from_name(model_name=model_name)
Expand All @@ -27,6 +28,7 @@ def __init__(self,
self.momentum = momentum
self.adversarial_poison_path = adversarial_poison_path
self.unlearnable_poison_path = unlearnable_poison_path
self.base_dataset_path = base_dataset_path
self.augmentations_key = augmentations_key
self.loss_fn = self.configure_criterion()
self.save_hyperparameters()
Expand Down Expand Up @@ -93,7 +95,7 @@ def train_dataloader(self):
])
transform_train = self.configure_transform(transform_train)
trainset = datasets.CIFAR100(
root='/vulcanscratch/psando/cifar-10/', train=True, download=False, transform=transform_train)
root=self.base_dataset_path, train=True, download=False, transform=transform_train)
if self.adversarial_poison_path:
trainset = AdversarialPoison(root=self.adversarial_poison_path, baseset=trainset)
if self.unlearnable_poison_path:
Expand All @@ -108,7 +110,7 @@ def val_dataloader(self):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = datasets.CIFAR100(
root='/vulcanscratch/psando/cifar-10/', train=False, download=False, transform=transform_test)
root=self.base_dataset_path, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand All @@ -118,7 +120,7 @@ def test_dataloader(self):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = datasets.CIFAR100(
root='/vulcanscratch/psando/cifar-10/', train=False, download=False, transform=transform_test)
root=self.base_dataset_path, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand Down
8 changes: 5 additions & 3 deletions lightning_modules/lightning_stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
momentum=0.9,
adversarial_poison_path=False,
unlearnable_poison_path=False,
base_dataset_path=None,
augmentations_key=None):
super().__init__()
self.model = get_model_class_from_name(model_name=model_name)
Expand All @@ -27,6 +28,7 @@ def __init__(self,
self.momentum = momentum
self.adversarial_poison_path = adversarial_poison_path
self.unlearnable_poison_path = unlearnable_poison_path
self.base_dataset_path = base_dataset_path
self.augmentations_key = augmentations_key
self.loss_fn = self.configure_criterion()
self.save_hyperparameters()
Expand Down Expand Up @@ -93,7 +95,7 @@ def train_dataloader(self):
transforms.Normalize((0.44671047,0.43981034,0.40664658), (0.26034108, 0.25657734, 0.27126735)),
])
transform_train = self.configure_transform(transform_train)
trainset = datasets.STL10(root='/vulcanscratch/psando/STL', split='train', download=False, transform=transform_train)
trainset = datasets.STL10(root=self.base_dataset_path, split='train', download=False, transform=transform_train)
if self.adversarial_poison_path:
trainset = AdversarialPoison(root=self.adversarial_poison_path, baseset=trainset)
if self.unlearnable_poison_path:
Expand All @@ -107,7 +109,7 @@ def val_dataloader(self):
transforms.ToTensor(),
transforms.Normalize((0.44671047, 0.43981034, 0.40664658), (0.26034108, 0.25657734, 0.27126735)),
])
testset = datasets.STL10(root='/vulcanscratch/psando/STL', split='test', download=False, transform=transform_test)
testset = datasets.STL10(root=self.base_dataset_path, split='test', download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand All @@ -116,7 +118,7 @@ def test_dataloader(self):
transforms.ToTensor(),
transforms.Normalize((0.44671047, 0.43981034, 0.40664658), (0.26034108, 0.25657734, 0.27126735)),
])
testset = datasets.STL10(root='/vulcanscratch/psando/STL', split='test', download=False, transform=transform_test)
testset = datasets.STL10(root=self.base_dataset_path, split='test', download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand Down
8 changes: 5 additions & 3 deletions lightning_modules/lightning_svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
momentum=0.9,
adversarial_poison_path=False,
unlearnable_poison_path=False,
base_dataset_path=None,
augmentations_key=None):
super().__init__()
self.model = get_model_class_from_name(model_name=model_name)
Expand All @@ -27,6 +28,7 @@ def __init__(self,
self.momentum = momentum
self.adversarial_poison_path = adversarial_poison_path
self.unlearnable_poison_path = unlearnable_poison_path
self.base_dataset_path = base_dataset_path
self.augmentations_key = augmentations_key
self.loss_fn = self.configure_criterion()
self.save_hyperparameters()
Expand Down Expand Up @@ -92,7 +94,7 @@ def train_dataloader(self):
transforms.Normalize((0.43768218, 0.44376934, 0.47280428), (0.1980301, 0.2010157, 0.19703591)),
])
transform_train = self.configure_transform(transform_train)
trainset = datasets.SVHN(root='/vulcanscratch/psando/SVHN', split='train', download=False, transform=transform_train)
trainset = datasets.SVHN(root=self.base_dataset_path, split='train', download=False, transform=transform_train)
if self.adversarial_poison_path:
trainset = AdversarialPoison(root=self.adversarial_poison_path, baseset=trainset)
if self.unlearnable_poison_path:
Expand All @@ -106,7 +108,7 @@ def val_dataloader(self):
transforms.ToTensor(),
transforms.Normalize((0.43768218, 0.44376934, 0.47280428), (0.1980301, 0.2010157, 0.19703591)),
])
testset = datasets.SVHN(root='/vulcanscratch/psando/SVHN', split='test', download=False, transform=transform_test)
testset = datasets.SVHN(root=self.base_dataset_path, split='test', download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand All @@ -115,7 +117,7 @@ def test_dataloader(self):
transforms.ToTensor(),
transforms.Normalize((0.43768218, 0.44376934, 0.47280428), (0.1980301, 0.2010157, 0.19703591)),
])
testset = datasets.SVHN(root='/vulcanscratch/psando/SVHN', split='test', download=False, transform=transform_test)
testset = datasets.SVHN(root=self.base_dataset_path, split='test', download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
return testloader

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ numpy
matplotlib
torch
pillow
scikit-learn
scikit-learn
einops
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def main(cfg : DictConfig) -> None:
momentum=cfg.train.momentum,
adversarial_poison_path=cfg.train.adversarial_poison_path,
unlearnable_poison_path=cfg.train.unlearnable_poison_path,
base_dataset_path=cfg.train.dataset_path,
augmentations_key=cfg.train.augmentations_key)
elif cfg.train.dataset == 'CIFAR100':
model = LitCIFAR100Model(model_name=cfg.train.model_name,
Expand All @@ -32,6 +33,7 @@ def main(cfg : DictConfig) -> None:
momentum=cfg.train.momentum,
adversarial_poison_path=cfg.train.adversarial_poison_path,
unlearnable_poison_path=cfg.train.unlearnable_poison_path,
base_dataset_path=cfg.train.dataset_path,
augmentations_key=cfg.train.augmentations_key)
elif cfg.train.dataset == 'STL10':
model = LitSTLModel(model_name=cfg.train.model_name,
Expand All @@ -42,6 +44,7 @@ def main(cfg : DictConfig) -> None:
momentum=cfg.train.momentum,
adversarial_poison_path=cfg.train.adversarial_poison_path,
unlearnable_poison_path=cfg.train.unlearnable_poison_path,
base_dataset_path=cfg.train.dataset_path,
augmentations_key=cfg.train.augmentations_key)
elif cfg.train.dataset == 'SVHN':
model = LitSVHNModel(model_name=cfg.train.model_name,
Expand All @@ -52,6 +55,7 @@ def main(cfg : DictConfig) -> None:
momentum=cfg.train.momentum,
adversarial_poison_path=cfg.train.adversarial_poison_path,
unlearnable_poison_path=cfg.train.unlearnable_poison_path,
base_dataset_path=cfg.train.dataset_path,
augmentations_key=cfg.train.augmentations_key)
else:
raise ValueError(f"Dataset {cfg.train.dataset} not supported.")
Expand Down

0 comments on commit f861cb8

Please sign in to comment.