MaxVit model #6342

Merged Sep 23, 2022 (33 commits)

Changes from 5 commits

Commits
f15fd92
Added maxvit architecture and tests
TeodorPoncu Aug 1, 2022
c5b2839
rebased + addressed comments
TeodorPoncu Aug 5, 2022
5e8a222
Revert "rebased + addresed comments"
TeodorPoncu Aug 5, 2022
aa95139
Re-added model changes after revert
TeodorPoncu Aug 5, 2022
1fddecc
aligned with partial original implementation
TeodorPoncu Sep 14, 2022
b7f0e97
removed submitit script fixed lint
TeodorPoncu Sep 14, 2022
872f40f
mypy fix for too many arguments
TeodorPoncu Sep 14, 2022
f561edf
updated old tests
TeodorPoncu Sep 14, 2022
314b82a
removed per batch lr scheduler and seed setting
TeodorPoncu Sep 16, 2022
a4863e9
removed ontap
TeodorPoncu Sep 16, 2022
c4406e4
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 16, 2022
2111680
added docs, validated weights
TeodorPoncu Sep 16, 2022
cc51c2b
fixed test expect, moved shape assertions to the beginning for torch.fx…
TeodorPoncu Sep 17, 2022
d2dfe71
mypy fix
TeodorPoncu Sep 17, 2022
328f9b6
lint fix
TeodorPoncu Sep 18, 2022
b334b7f
added legacy interface
TeodorPoncu Sep 18, 2022
ebb8c16
added weight link
TeodorPoncu Sep 20, 2022
e281371
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 20, 2022
20422bc
updated docs
TeodorPoncu Sep 21, 2022
9ad86fe
Merge branch 'BATERIES]-add-max-vit' of https://github.com/pytorch/vi…
TeodorPoncu Sep 21, 2022
775990c
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 21, 2022
a24e549
Update references/classification/train.py
TeodorPoncu Sep 21, 2022
bb42548
Update torchvision/models/maxvit.py
TeodorPoncu Sep 21, 2022
ed21d3d
addressed comments
TeodorPoncu Sep 21, 2022
09e4ced
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 22, 2022
521d6d5
update ra_magnitude and augmix_severity default values
TeodorPoncu Sep 22, 2022
79cb004
Merge branch 'BATERIES]-add-max-vit' of https://github.com/pytorch/vi…
TeodorPoncu Sep 22, 2022
97cbcd8
addressed some comments
TeodorPoncu Sep 22, 2022
9fc6a5b
Merge branch 'BATERIES]-add-max-vit' of https://github.com/pytorch/vi…
TeodorPoncu Sep 22, 2022
6b00ca8
remove input_channels parameter
TeodorPoncu Sep 23, 2022
45d3966
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 23, 2022
2aca920
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 23, 2022
cab35c1
Merge branch 'main' into BATERIES]-add-max-vit
TeodorPoncu Sep 23, 2022
6 changes: 4 additions & 2 deletions references/classification/presets.py
@@ -13,14 +13,16 @@ def __init__(
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
policy_magnitude=9,
random_erase_prob=0.0,
center_crop=False,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
trans = [transforms.CenterCrop(crop_size)] if center_crop else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation))
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=policy_magnitude))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
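For context, here is a minimal sketch of how the two new preset options could be exercised together; the import path and the magnitude value are illustrative assumptions, not part of the PR:

# Hypothetical usage of the updated preset (sketch only):
from presets import ClassificationPresetTrain
from torchvision.transforms.functional import InterpolationMode

preset = ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BILINEAR,
    auto_augment_policy="ra",   # RandAugment now receives policy_magnitude
    policy_magnitude=15,        # illustrative value; the default remains 9
    center_crop=True,           # deterministic CenterCrop instead of RandomResizedCrop
)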
122 changes: 122 additions & 0 deletions references/classification/run_with_submitit.py
@@ -0,0 +1,122 @@
import argparse
import os
import uuid
from pathlib import Path

import train
import submitit


def parse_args():
train_parser = train.get_args_parser(add_help=False)
parser = argparse.ArgumentParser(description="Submitit for train", parents=[train_parser], add_help=True)
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
parser.add_argument("--timeout", default=60*24*30, type=int, help="Duration of the job")
parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
parser.add_argument("--partition", default="train", type=str, help="the partition (default train).")
return parser.parse_args()


def get_shared_folder() -> Path:
user = os.getenv("USER")
path = "/data/checkpoints"
if Path(path).is_dir():
p = Path(f"{path}/{user}/experiments")
p.mkdir(parents=True, exist_ok=True)
return p
raise RuntimeError("No shared folder available")


def get_init_file_folder() -> Path:
user = os.getenv("USER")
path = "/shared"
if Path(path).is_dir():
p = Path(f"{path}/{user}")
p.mkdir(parents=True, exist_ok=True)
return p
raise RuntimeError("No shared folder available")


def get_init_file():
# Init file must not exist, but its parent dir must exist.
os.makedirs(str(get_init_file_folder()), exist_ok=True)
init_file = get_init_file_folder() / f"{uuid.uuid4().hex}_init"
if init_file.exists():
os.remove(str(init_file))
return init_file


class Trainer(object):
def __init__(self, args):
self.args = args

def __call__(self):
import train

self._setup_gpu_args()
train.main(self.args)

def checkpoint(self):
import os
import submitit
from pathlib import Path

self.args.dist_url = get_init_file().as_uri()
checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
if os.path.exists(checkpoint_file):
self.args.resume = checkpoint_file
print("Requeuing ", self.args)
empty_trainer = type(self)(self.args)
return submitit.helpers.DelayedSubmission(empty_trainer)

def _setup_gpu_args(self):
import submitit
from pathlib import Path

job_env = submitit.JobEnvironment()
self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
self.args.gpu = job_env.local_rank
self.args.rank = job_env.global_rank
self.args.world_size = job_env.num_tasks
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
args = parse_args()
if args.job_dir == "":
args.job_dir = get_shared_folder() / "%j"

# Note that the folder will depend on the job_id, to easily track experiments
executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=300)

# cluster setup is defined by environment variables
num_gpus_per_node = args.ngpus
nodes = args.nodes
timeout_min = args.timeout

executor.update_parameters(
#mem_gb=96 * num_gpus_per_node, # 768GB per machine
gpus_per_node=num_gpus_per_node,
tasks_per_node=num_gpus_per_node, # one task per GPU
cpus_per_task=12, # 96 cpus per machine
nodes=nodes,
timeout_min=timeout_min, # max is 60 * 72
slurm_partition=args.partition,
slurm_signal_delay_s=120,
)


executor.update_parameters(name="torchvision")

args.dist_url = get_init_file().as_uri()
args.output_dir = args.job_dir

trainer = Trainer(args)
job = executor.submit(trainer)

print("Submitted job_id:", job.job_id)


if __name__ == "__main__":
main()
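For reference, a hypothetical launch of this script could look like the following (model name, node counts, and batch size are placeholders, not the official recipe):

# python run_with_submitit.py --model maxvit_t --nodes 4 --ngpus 8 \
#     --partition train --batch-size 32 --epochs 400
#
# The job folder defaults to /data/checkpoints/$USER/experiments/%j, and
# _setup_gpu_args later replaces %j with the actual SLURM job id.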
47 changes: 36 additions & 11 deletions references/classification/train.py
@@ -1,5 +1,6 @@
import datetime
import os
import random
import time
import warnings

@@ -15,7 +16,7 @@
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, scheduler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -43,6 +44,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args
if args.clip_grad_norm is not None:
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()

if scheduler is not None and args.lr_step_every_batch:
scheduler.step()

if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
@@ -113,7 +117,7 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
val_resize_size, val_crop_size, train_crop_size, center_crop, policy_magnitude = args.val_resize_size, args.val_crop_size, args.train_crop_size, args.train_center_crop, args.policy_magnitude
interpolation = InterpolationMode(args.interpolation)

print("Loading training data")
@@ -129,10 +133,12 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
center_crop=center_crop,
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
policy_magnitude=policy_magnitude,
),
)
if args.cache_dataset:
@@ -182,7 +188,12 @@
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)


if args.seed is None:
# randomly choose a seed (numpy accepts seeds in [0, 2**32 - 1])
args.seed = random.randint(0, 2**32 - 1)
utils.set_seed(args.seed)

utils.init_distributed_mode(args)
print(args)

@@ -261,13 +272,21 @@ def main(args):
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

batches_per_epoch = len(data_loader)
warmup_iters = args.lr_warmup_epochs
total_iters = args.epochs

if args.lr_step_every_batch:
warmup_iters *= batches_per_epoch
total_iters *= batches_per_epoch

args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == "steplr":
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
optimizer, T_max=total_iters - warmup_iters, eta_min=args.lr_min
)
elif args.lr_scheduler == "exponentiallr":
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
@@ -280,18 +299,18 @@ def main(args):
if args.lr_warmup_epochs > 0:
if args.lr_warmup_method == "linear":
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
)
elif args.lr_warmup_method == "constant":
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
)
else:
raise RuntimeError(
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
)
else:
lr_scheduler = main_lr_scheduler
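To make the per-batch stepping concrete, here is a small self-contained sketch of the schedule construction above (the epoch and batch counts are illustrative, not the MaxViT recipe):

import torch

# Illustrative numbers: 5 warmup epochs, 300 epochs total, 1000 batches per epoch.
epochs, warmup_epochs, batches_per_epoch = 300, 5, 1000
lr_step_every_batch = True

scale = batches_per_epoch if lr_step_every_batch else 1
warmup_iters, total_iters = warmup_epochs * scale, epochs * scale

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
main = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, main], milestones=[warmup_iters]
)

# With --lr-step-every-batch, scheduler.step() runs after every optimizer.step()
# inside train_one_epoch rather than once per epoch, so every milestone above is
# expressed in batches instead of epochs.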
@@ -341,8 +360,9 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, lr_scheduler)
if not args.lr_step_every_batch:
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
@@ -371,7 +391,7 @@ def get_args_parser(add_help=True):

parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--data-path", default="/datasets01_ontap/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
@@ -425,6 +445,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
parser.add_argument("--lr-step-every-batch", action="store_true", help="decrease lr every step-size batches", default=False)
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
@@ -448,6 +469,7 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
parser.add_argument("--policy-magnitude", default=9, type=int, help="magnitude of auto augment policy")
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

# Mixed precision training parameters
@@ -486,13 +508,16 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument(
"--train-center-crop", action="store_true", help="use center crop instead of random crop for training (default: False)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

parser.add_argument("--seed", default=None, type=int, help="the seed for randomness (default: None). A `None` value means a seed will be randomly generated")
return parser


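Putting the new flags together, a hypothetical invocation of the reference script might look like this (hyperparameter values are placeholders, not the official MaxViT recipe):

# python train.py --model maxvit_t --lr-scheduler cosineannealinglr \
#     --lr-warmup-epochs 32 --lr-step-every-batch \
#     --auto-augment ra --policy-magnitude 15 \
#     --train-center-crop --train-crop-size 224 --seed 42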
12 changes: 12 additions & 0 deletions references/classification/utils.py
@@ -9,6 +9,8 @@

import torch
import torch.distributed as dist
import numpy as np
import random


class SmoothedValue:
@@ -463,3 +465,13 @@ def _add_params(module, prefix=""):
if len(params[key]) > 0:
param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
return param_groups

def set_seed(seed: int):
"""
Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs to make runs reproducible.
"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
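A quick sanity check of the helper (a sketch; it assumes the reference utils module is importable from the working directory):

import random

import numpy as np
import torch

import utils  # references/classification/utils.py

utils.set_seed(42)
first = (torch.rand(1).item(), np.random.rand(), random.random())
utils.set_seed(42)
second = (torch.rand(1).item(), np.random.rand(), random.random())
assert first == second  # identical draws after re-seeding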
Binary file added test/expect/ModelTester.test_maxvit_t_expect.pkl
39 changes: 39 additions & 0 deletions test/test_architecture_ops.py
@@ -0,0 +1,39 @@
import unittest

import pytest
import torch

from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):
[Review comment, Contributor] I understand that here you are testing specific layers from MaxViT. This is not something we did previously, so perhaps it does need to be in a separate file. @YosuaMichael, any thoughts here?

[Review reply, Contributor] Sorry for the pretty late response! Currently I don't have an opinion on how we should test specific layers of the model, and I think this is okay. (I need more time to think about and discuss whether we should do more of this kind of test.)

def test_maxvit_window_partition(self):
input_shape = (1, 3, 224, 224)
partition_size = 7

x = torch.randn(input_shape)

partition = WindowPartition(partition_size=partition_size)
departition = WindowDepartition(partition_size=partition_size, n_partitions=(input_shape[3] // partition_size))

assert torch.allclose(x, departition(partition(x)))

def test_maxvit_grid_partition(self):
input_shape = (1, 3, 224, 224)
partition_size = 7

x = torch.randn(input_shape)
partition = torch.nn.Sequential(
WindowPartition(partition_size=(input_shape[3] // partition_size)),
SwapAxes(-2, -3),
)
departition = torch.nn.Sequential(
SwapAxes(-2, -3),
WindowDepartition(partition_size=(input_shape[3] // partition_size), n_partitions=partition_size),
)

assert torch.allclose(x, departition(partition(x)))


if __name__ == "__main__":
pytest.main([__file__])
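For intuition about what these round-trip tests exercise: window partitioning splits a (B, C, H, W) feature map into non-overlapping p x p windows. A minimal tensor-level equivalent, assuming H and W are divisible by p (a sketch, not the torchvision implementation):

import torch

def window_partition_sketch(x: torch.Tensor, p: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, H//p * W//p, p*p, C): one row per window, one entry per pixel
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # B, H//p, W//p, p, p, C
    return x.reshape(B, (H // p) * (W // p), p * p, C)

x = torch.randn(1, 3, 224, 224)
print(window_partition_sketch(x, 7).shape)  # torch.Size([1, 1024, 49, 3])

Grid partitioning attends over a dilated grid rather than contiguous windows; the tests above build it by wrapping a WindowPartition of the complementary size with SwapAxes, which exchanges the window and pixel axes.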
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -13,5 +13,6 @@
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_weights, get_weight, list_models
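With this import in place, the new model is reachable through both the builder and the model registration API; a minimal check (weights are skipped to avoid a download):

import torchvision.models as models

model = models.maxvit_t(weights=None)                 # builder exported from maxvit.py
model = models.get_model("maxvit_t", weights=None)    # via the registration API above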