Move from Argparse to Omegaconf (#300)
vturrisi authored Sep 27, 2022
1 parent be3de80 commit 6f228fd
Showing 254 changed files with 7,657 additions and 6,108 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -43,7 +43,7 @@ jobs:
key: ${{ runner.os }}

- name: pytest
run: pytest --cov=solo tests/args tests/backbones tests/losses tests/methods tests/utils
run: pytest --cov=solo tests/args tests/backbones tests/data tests/losses tests/methods tests/utils

- name: Statistics
if: success()
3 changes: 3 additions & 0 deletions .gitignore
@@ -4,6 +4,7 @@ runs/
# wandb dir
wandb/
wandb*/
!scripts/**

# umap dir
auto_umap/
@@ -46,6 +47,8 @@ lightning_logs/
*logs*/
*output*/

.hydra

# Created by https://www.gitignore.io/api/python,visualstudiocode
# Edit at https://www.gitignore.io/?templates=python,visualstudiocode

29 changes: 25 additions & 4 deletions README.md
@@ -15,6 +15,7 @@ The library is self-contained, but it is possible to use the models outside of s
---

## News
* **[Sep 27 2022]**: :pencil: Brand new config system using OmegaConf/Hydra. Adds more clarity and flexibility. New tutorials will follow soon!
* **[Aug 04 2022]**: :paintbrush: Added [MAE](https://arxiv.org/abs/2111.06377) and supports finetuning the backbone with `main_linear.py`, mixup, cutmix and [random augment](https://arxiv.org/abs/1909.13719).
* **[Jul 13 2022]**: :sparkling_heart: Added support for [H5](https://docs.h5py.org/en/stable/index.html) data, improved scripts and data handling.
* **[Jun 26 2022]**: :fire: Added [MoCo V3](https://arxiv.org/abs/2104.02057).
@@ -44,7 +45,15 @@ The library is self-contained, but it is possible to use the models outside of s

---

## Methods available:
## Roadmap and help needed
* Redo the documentation to improve clarity.
* Write better, up-to-date tutorials.
* Add performance-related tests to ensure that methods perform the same across updates.
* Add new methods (a continuous effort).

---

## Methods available
* [Barlow Twins](https://arxiv.org/abs/2103.03230)
* [BYOL](https://arxiv.org/abs/2006.07733)
* [DeepCluster V2](https://arxiv.org/abs/2006.09882)
@@ -151,17 +160,29 @@ pre-commit install

## Training

For pretraining the backbone, follow one of the many bash files in `bash_files/pretrain/`.
For pretraining the backbone, follow one of the many bash files in `scripts/pretrain/`.
We are now using [Hydra](https://github.com/facebookresearch/hydra) to handle the config files, so a typical command looks like this:
```bash
# --config-path: path to the folder with the training configs
# --config-name: name of the training config
# extra arguments (e.g. those not defined in the yaml files) can be appended
# with ++new_argument=VALUE; PyTorch Lightning's arguments can be added here as well
python3 main_pretrain.py \
    --config-path scripts/pretrain/imagenet-100/ \
    --config-name barlow.yaml
```
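
For reference, here is a minimal sketch of the Hydra entrypoint pattern the command above relies on. It mirrors the `@hydra.main(version_base="1.2")` decorator used in the new scripts (see the `main_linear.py` diff below); the printout is illustrative only, not the repository's actual `main_pretrain.py`.

```python
# Minimal sketch of a Hydra entrypoint; illustrative, not repository code.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base="1.2")
def main(cfg: DictConfig) -> None:
    # --config-path/--config-name are supplied on the command line;
    # the yaml file plus any ++key=value overrides arrive as one DictConfig.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```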

After that, for offline linear evaluation, follow the examples in `bash_files/linear`.
After that, for offline linear evaluation, follow the examples in `scripts/linear` or `scripts/finetune` for finetuning the whole backbone.

There are extra experiments on K-NN evaluation in `bash_files/knn/` and feature visualization with UMAP in `bash_files/umap/`.
For k-NN evaluation and UMAP visualization check the scripts in `scripts/{knn,umap}`.

**NOTE:** The config files aim to stay up to date and to follow each paper's recommended parameters as closely as possible, but double-check them before running.

---

## Tutorials

Please, check out our [documentation](https://solo-learn.readthedocs.io/en/latest) and tutorials:
* [Overview](https://solo-learn.readthedocs.io/en/latest/tutorials/overview.html)
* [Offline linear eval](https://solo-learn.readthedocs.io/en/latest/tutorials/offline_linear_eval.html)
7 changes: 7 additions & 0 deletions config/pretrain.yaml
@@ -0,0 +1,7 @@
defaults:
- _self_
- augmentations: null
- wandb: null

name: null
method: null
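
The `defaults` list above composes the top-level config with the `augmentations` and `wandb` groups selected on the command line. If you want to inspect the composed config outside of a training run, a small sketch using Hydra's compose API is shown below; it assumes Hydra ≥ 1.2 and being run from the repository root, and the override values (`byol`, `test-run`) are placeholders rather than repository defaults.

```python
# Sketch: composing config/pretrain.yaml programmatically; illustrative only.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.2", config_path="config"):
    # ++key=value adds or overrides keys, just like on the command line
    cfg = compose(config_name="pretrain", overrides=["++method=byol", "++name=test-run"])

print(OmegaConf.to_yaml(cfg))
```
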
2 changes: 1 addition & 1 deletion main_knn.py
@@ -27,7 +27,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from solo.args.setup import parse_args_knn
from solo.args.knn import parse_args_knn
from solo.data.classification_dataloader import (
prepare_dataloaders,
prepare_datasets,
153 changes: 76 additions & 77 deletions main_linear.py
@@ -17,19 +17,22 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import inspect
import logging
import os

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from solo.args.setup import parse_args_linear
from solo.args.linear import parse_cfg
from solo.data.classification_dataloader import prepare_data
from solo.methods.base import BaseMethod
from solo.methods.linear import LinearModel
@@ -45,38 +48,28 @@
_dali_avaliable = True


def main():
args = parse_args_linear()
@hydra.main(version_base="1.2")
def main(cfg: DictConfig):
# hydra doesn't allow us to add new keys for "safety"
# set_struct(..., False) disables this behavior and allows us to add more parameters
# without making the user specify every single thing about the model
OmegaConf.set_struct(cfg, False)
cfg = parse_cfg(cfg)

assert args.backbone in BaseMethod._BACKBONES
backbone_model = BaseMethod._BACKBONES[args.backbone]
backbone_model = BaseMethod._BACKBONES[cfg.backbone.name]

# initialize backbone
kwargs = args.backbone_args
cifar = kwargs.pop("cifar", False)
# swin specific
if "swin" in args.backbone and cifar:
kwargs["window_size"] = 4

if "vit" in args.backbone:
kwargs["drop_path_rate"] = args.drop_path
kwargs["global_pool"] = args.global_pool

method = args.pretrain_method
backbone = backbone_model(method=method, **kwargs)
if args.backbone.startswith("resnet"):
backbone = backbone_model(method=cfg.pretrain_method, **cfg.backbone.kwargs)
if cfg.backbone.name.startswith("resnet"):
# remove fc layer
backbone.fc = nn.Identity()
cifar = cfg.data.dataset in ["cifar10", "cifar100"]
if cifar:
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
backbone.maxpool = nn.Identity()

assert (
args.pretrained_feature_extractor.endswith(".ckpt")
or args.pretrained_feature_extractor.endswith(".pth")
or args.pretrained_feature_extractor.endswith(".pt")
)
ckpt_path = args.pretrained_feature_extractor
ckpt_path = cfg.pretrained_feature_extractor
assert ckpt_path.endswith(".ckpt") or ckpt_path.endswith(".pth") or ckpt_path.endswith(".pt")

state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
for k in list(state.keys()):
@@ -93,63 +86,62 @@ def main():

# check if mixup or cutmix is enabled
mixup_func = None
mixup_active = args.mixup > 0 or args.cutmix > 0
mixup_active = cfg.mixup > 0 or cfg.cutmix > 0
if mixup_active:
logging.info("Mixup activated")
mixup_func = Mixup(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
mixup_alpha=cfg.mixup,
cutmix_alpha=cfg.cutmix,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode="batch",
label_smoothing=args.label_smoothing,
num_classes=args.num_classes,
label_smoothing=cfg.label_smoothing,
num_classes=cfg.data.num_classes,
)
# smoothing is handled with mixup label transform
loss_func = SoftTargetCrossEntropy()
elif args.label_smoothing > 0:
loss_func = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
elif cfg.label_smoothing > 0:
loss_func = LabelSmoothingCrossEntropy(smoothing=cfg.label_smoothing)
else:
loss_func = torch.nn.CrossEntropyLoss()

del args.backbone
model = LinearModel(backbone, loss_func=loss_func, mixup_func=mixup_func, **args.__dict__)
model = LinearModel(backbone, loss_func=loss_func, mixup_func=mixup_func, cfg=cfg)
make_contiguous(model)
# can provide up to ~20% speed up
if not args.no_channel_last:
if not cfg.performance.disable_channel_last:
model = model.to(memory_format=torch.channels_last)

if args.data_format == "dali":
if cfg.data.format == "dali":
val_data_format = "image_folder"
else:
val_data_format = args.data_format
val_data_format = cfg.data.format

train_loader, val_loader = prepare_data(
args.dataset,
train_data_path=args.train_data_path,
val_data_path=args.val_data_path,
cfg.data.dataset,
train_data_path=cfg.data.train_path,
val_data_path=cfg.data.val_path,
data_format=val_data_format,
batch_size=args.batch_size,
num_workers=args.num_workers,
auto_augment=args.auto_augment,
batch_size=cfg.optimizer.batch_size,
num_workers=cfg.data.num_workers,
auto_augment=cfg.auto_augment,
)

if args.data_format == "dali":
if cfg.data.format == "dali":
assert (
_dali_avaliable
), "Dali is not currently avaiable, please install it first with pip3 install .[dali]."

assert not args.auto_augment, "Auto augmentation is not supported with Dali."
assert not cfg.auto_augment, "Auto augmentation is not supported with Dali."

dali_datamodule = ClassificationDALIDataModule(
dataset=args.dataset,
train_data_path=args.train_data_path,
val_data_path=args.val_data_path,
num_workers=args.num_workers,
batch_size=args.batch_size,
data_fraction=args.data_fraction,
dali_device=args.dali_device,
dataset=cfg.data.dataset,
train_data_path=cfg.data.train_path,
val_data_path=cfg.data.val_path,
num_workers=cfg.data.num_workers,
batch_size=cfg.optimizer.batch_size,
data_fraction=cfg.data.fraction,
dali_device=cfg.dali.device,
)

# use normal torchvision dataloader for validation to save memory
@@ -158,59 +150,66 @@ def main():
# 1.7 will deprecate resume_from_checkpoint, but for the moment
# the argument is the same, but we need to pass it as ckpt_path to trainer.fit
ckpt_path, wandb_run_id = None, None
if args.auto_resume and args.resume_from_checkpoint is None:
if cfg.auto_resume.enabled and cfg.resume_from_checkpoint is None:
auto_resumer = AutoResumer(
checkpoint_dir=os.path.join(args.checkpoint_dir, "linear"),
max_hours=args.auto_resumer_max_hours,
checkpoint_dir=os.path.join(cfg.checkpoint.dir, "linear"),
max_hours=cfg.auto_resume.max_hours,
)
resume_from_checkpoint, wandb_run_id = auto_resumer.find_checkpoint(args)
resume_from_checkpoint, wandb_run_id = auto_resumer.find_checkpoint(cfg)
if resume_from_checkpoint is not None:
print(
"Resuming from previous checkpoint that matches specifications:",
f"'{resume_from_checkpoint}'",
)
ckpt_path = resume_from_checkpoint
elif args.resume_from_checkpoint is not None:
ckpt_path = args.resume_from_checkpoint
del args.resume_from_checkpoint
elif cfg.resume_from_checkpoint is not None:
ckpt_path = cfg.resume_from_checkpoint
del cfg.resume_from_checkpoint

callbacks = []

if args.save_checkpoint:
if cfg.checkpoint.enabled:
# save checkpoint on last epoch only
ckpt = Checkpointer(
args,
logdir=os.path.join(args.checkpoint_dir, "linear"),
frequency=args.checkpoint_frequency,
cfg,
logdir=os.path.join(cfg.checkpoint.dir, "linear"),
frequency=cfg.checkpoint.frequency,
keep_prev=cfg.checkpoint.keep_prev,
)
callbacks.append(ckpt)

# wandb logging
if args.wandb:
if cfg.wandb.enabled:
wandb_logger = WandbLogger(
name=args.name,
project=args.project,
entity=args.entity,
offline=args.offline,
name=cfg.name,
project=cfg.wandb.project,
entity=cfg.wandb.entity,
offline=cfg.wandb.offline,
resume="allow" if wandb_run_id else None,
id=wandb_run_id,
)
wandb_logger.watch(model, log="gradients", log_freq=100)
wandb_logger.log_hyperparams(args)
wandb_logger.log_hyperparams(OmegaConf.to_container(cfg))

# lr logging
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks.append(lr_monitor)

trainer = Trainer.from_argparse_args(
args,
logger=wandb_logger if args.wandb else None,
callbacks=callbacks,
enable_checkpointing=False,
strategy=DDPStrategy(find_unused_parameters=False)
if args.strategy == "ddp"
else args.strategy,
trainer_kwargs = OmegaConf.to_container(cfg)
# we only want to pass in valid Trainer args, the rest may be user specific
valid_kwargs = inspect.signature(Trainer.__init__).parameters
trainer_kwargs = {name: trainer_kwargs[name] for name in valid_kwargs if name in trainer_kwargs}
trainer_kwargs.update(
{
"logger": wandb_logger if cfg.wandb.enabled else None,
"callbacks": callbacks,
"enable_checkpointing": False,
"strategy": DDPStrategy(find_unused_parameters=False)
if cfg.strategy == "ddp"
else cfg.strategy,
}
)
trainer = Trainer(**trainer_kwargs)

# fix for incompatibility with nvidia-dali and pytorch lightning
# with dali 1.15 (this will be fixed on 1.16)
@@ -229,7 +228,7 @@ def prefetch_batches(self) -> int:
except:
pass

if args.data_format == "dali":
if cfg.data.format == "dali":
trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule)
else:
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
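
The new script leans on two small OmegaConf patterns: `OmegaConf.set_struct(cfg, False)` to allow adding keys to the composed config, and `inspect.signature` filtering so that only valid `Trainer` arguments are passed through. Below is a standalone sketch of both (not repository code; the config keys are invented for illustration).

```python
# Standalone sketch of the two OmegaConf patterns used in the new main_linear.py;
# the config keys below are made up for illustration.
import inspect

from omegaconf import OmegaConf
from pytorch_lightning import Trainer

cfg = OmegaConf.create({"max_epochs": 10, "devices": 1})

# 1) struct mode: a Hydra-composed config is usually "struct", which forbids
#    adding new keys; disabling it lets extra defaults be filled in later.
OmegaConf.set_struct(cfg, False)
cfg.my_custom_flag = True  # would raise an error in struct mode

# 2) Trainer kwargs filtering: keep only the keys Trainer.__init__ accepts,
#    so user-specific keys such as my_custom_flag don't break Trainer(**kwargs).
trainer_kwargs = OmegaConf.to_container(cfg)
valid_kwargs = inspect.signature(Trainer.__init__).parameters
trainer_kwargs = {name: trainer_kwargs[name] for name in valid_kwargs if name in trainer_kwargs}

trainer = Trainer(**trainer_kwargs)
```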