simplify SK code + large models on hubconf + remove opencv blur
Mathilde Caron committed Apr 2, 2021
1 parent 1018366 commit 9a2dc80
Showing 17 changed files with 100 additions and 92 deletions.
24 changes: 13 additions & 11 deletions README.md
@@ -14,8 +14,8 @@ Our method can be trained with large and small batches and can scale to unlimite

# Model Zoo

We release our best ResNet-50 pre-trained with SwAV with the hope that other researchers might also benefit by replacing the ImageNet supervised network with the SwAV backbone.
To load the model, simply do:
We release several models pre-trained with SwAV with the hope that other researchers might also benefit by replacing the ImageNet supervised network with the SwAV backbone.
To load our best SwAV pre-trained ResNet-50 model, simply do:
```python
import torch
model = torch.hub.load('facebookresearch/swav', 'resnet50')
```
@@ -41,6 +41,14 @@ We also provide models pre-trained with [DeepCluster-v2](./main_deepclusterv2.py

## Larger architectures
We provide SwAV models with ResNet-50 networks where we multiply the width by a factor of ×2, ×4, and ×5.
To load the corresponding backbone you can use:
```python
import torch
rn50w2 = torch.hub.load('facebookresearch/swav', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav', 'resnet50w5')
```

| network | parameters | epochs | ImageNet top-1 acc. | url | args |
|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|
| RN50-w2 | 94M | 400 | 77.3 | [model](https://dl.fbaipublicfiles.com/deepcluster/swav_RN50w2_400ep_pretrain.pth.tar) | [script](./scripts/swav_RN50w2_400ep_pretrain.sh) |
@@ -63,7 +71,7 @@ We provide the running times for some of our runs:
- torchvision
- CUDA 10.1
- [Apex](https://github.com/NVIDIA/apex) with CUDA extension (see [how I installed apex](https://github.com/facebookresearch/swav/issues/18#issuecomment-748123838))
- Other dependencies: opencv-python, scipy, pandas, numpy
- Other dependencies: scipy, pandas, numpy

## Singlenode training
SwAV is very simple to implement and experiment with.
@@ -160,9 +168,8 @@ We now analyze the collapsing problem: it happens when all examples are mapped t
In other words, the convnet always has the same output regardless of its input; it is a constant function.
All examples get the same cluster assignment because their representations are identical, and the only valid assignment that satisfies the equipartition constraint in this case is the uniform assignment (1/K, where K is the number of prototypes).
In turn, this uniform assignment is trivial to predict since it is the same for all examples.
Reducing the epsilon parameter (see Eq(3) of our [paper](https://arxiv.org/abs/2006.09882)) encourages the assignments `Q` to be less uniform (more peaked), which strongly helps avoid collapsing.
However, using too low a value for epsilon leads to numerical instability.
If the loss goes to NaN, you can try using `--improve_numerical_stability true` in `main_swav.py`.
Reducing the epsilon parameter (see Eq(3) of our [paper](https://arxiv.org/abs/2006.09882)) encourages the assignments `Q` to be sharper (i.e. less uniform), which strongly helps avoid collapse.
However, using too low a value for epsilon may lead to numerical instability.
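For intuition, here is a small illustrative sketch (not from the repository; the scores are made up) showing how dividing the prototype scores by epsilon before the exponential controls how peaked the resulting distribution is:
```python
import torch

scores = torch.tensor([2.0, 1.0, 0.5])  # hypothetical prototype scores for a single example

for eps in (1.0, 0.1, 0.05, 0.01):
    # same operation used to initialize Q before the Sinkhorn iterations: exp(score / eps), normalized
    q = torch.softmax(scores / eps, dim=0)
    print(f"epsilon={eps}: {q.tolist()}")
# a large epsilon gives a near-uniform distribution, a small epsilon a peaked one;
# very small values make exp(score / eps) overflow, which is the source of the instability
```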

#### Training gets unstable when using the queue.
The queue is composed of feature representations from the previous batches.
@@ -181,11 +188,6 @@ We observe that it made the loss go more down.
If, when introducing the queue, the loss goes up and does not decrease afterwards, you should stop your training and change the queue parameters.
We recommend (i) using a smaller queue, (ii) starting the queue later in training.
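For intuition, the queue is simply a fixed-size FIFO buffer of feature embeddings per crop used for assignment: the newest batch is written at the front and the oldest features fall off the end, exactly as in the update shown in the `main_swav.py` diff below. A minimal sketch with illustrative sizes (not the repository's API):
```python
import torch

queue_length, feat_dim, bs = 3840, 128, 64        # illustrative sizes
queue = torch.zeros(2, queue_length, feat_dim)    # one buffer per crop used for assignment

def push(buf, embeddings):
    """Shift the buffer back by one batch and write the newest embeddings at the front."""
    buf[bs:] = buf[:-bs].clone()
    buf[:bs] = embeddings
    return buf

queue[0] = push(queue[0], torch.randn(bs, feat_dim))  # oldest features are dropped at each step
```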

#### Slow training.
If you experience slow running times, it might be because of the Gaussian blur.
We recommend performing the Gaussian blur with the PIL library instead of open-cv (pass `--use_pil_blur true` as argument).
Note that we keep `--use_pil_blur false` in the [scripts](./scripts) because all our experiments were performed with open-cv, but we strongly recommend using the PIL library instead.

## License
See the [LICENSE](LICENSE) file for more details.

55 changes: 55 additions & 0 deletions hubconf.py
@@ -7,6 +7,10 @@
import torch
from torchvision.models.resnet import resnet50 as _resnet50

from src.resnet50 import resnet50w2 as _resnet50w2
from src.resnet50 import resnet50w4 as _resnet50w4
from src.resnet50 import resnet50w5 as _resnet50w5

dependencies = ["torch", "torchvision"]


@@ -29,3 +33,54 @@ def resnet50(pretrained=True, **kwargs):
# load weights
model.load_state_dict(state_dict, strict=False)
return model


def resnet50w2(pretrained=True, **kwargs):
"""
ResNet-50-w2 pre-trained with SwAV.
"""
model = _resnet50w2(**kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deepcluster/swav_RN50w2_400ep_pretrain.pth.tar",
map_location="cpu",
)
# removes "module."
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# load weights
model.load_state_dict(state_dict, strict=False)
return model


def resnet50w4(pretrained=True, **kwargs):
"""
ResNet-50-w4 pre-trained with SwAV.
"""
model = _resnet50w4(**kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deepcluster/swav_RN50w4_400ep_pretrain.pth.tar",
map_location="cpu",
)
# removes "module."
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# load weights
model.load_state_dict(state_dict, strict=False)
return model


def resnet50w5(pretrained=True, **kwargs):
"""
ResNet-50-w5 pre-trained with SwAV.
"""
model = _resnet50w5(**kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deepcluster/swav_RN50w5_400ep_pretrain.pth.tar",
map_location="cpu",
)
# removes "module."
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# load weights
model.load_state_dict(state_dict, strict=False)
return model
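As a quick usage check, the entry points defined above can be pulled through `torch.hub` just like the base `resnet50`; the sketch below assumes a standard 224×224 RGB input and only verifies that a forward pass runs (the feature dimensionality depends on the width multiplier):
```python
import torch

# load the widened SwAV backbone through the hub entry point defined above
rn50w2 = torch.hub.load('facebookresearch/swav', 'resnet50w2')
rn50w2.eval()

with torch.no_grad():
    out = rn50w2(torch.randn(1, 3, 224, 224))  # dummy image-sized batch
print(out.shape)
```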
3 changes: 0 additions & 3 deletions main_deepclusterv2.py
@@ -51,8 +51,6 @@
help="argument in RandomResizedCrop (example: [0.14, 0.05])")
parser.add_argument("--max_scale_crops", type=float, default=[1], nargs="+",
help="argument in RandomResizedCrop (example: [1., 0.14])")
parser.add_argument("--use_pil_blur", type=bool_flag, default=True,
help="""use PIL library to perform blur instead of opencv""")

#########################
## dcv2 specific params #
@@ -128,7 +126,6 @@ def main():
args.min_scale_crops,
args.max_scale_crops,
return_index=True,
pil_blur=args.use_pil_blur,
)
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
83 changes: 31 additions & 52 deletions main_swav.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
@@ -49,8 +50,6 @@
help="argument in RandomResizedCrop (example: [0.14, 0.05])")
parser.add_argument("--max_scale_crops", type=float, default=[1], nargs="+",
help="argument in RandomResizedCrop (example: [1., 0.14])")
parser.add_argument("--use_pil_blur", type=bool_flag, default=True,
help="""use PIL library to perform blur instead of opencv""")

#########################
## swav specific params #
@@ -61,8 +60,6 @@
help="temperature parameter in training loss")
parser.add_argument("--epsilon", default=0.05, type=float,
help="regularization parameter for Sinkhorn-Knopp algorithm")
parser.add_argument("--improve_numerical_stability", default=False, type=bool_flag,
help="improves numerical stability in Sinkhorn-Knopp algorithm")
parser.add_argument("--sinkhorn_iterations", default=3, type=int,
help="number of iterations in Sinkhorn-Knopp algorithm")
parser.add_argument("--feat_dim", default=128, type=int,
@@ -137,7 +134,6 @@ def main():
args.nmb_crops,
args.min_scale_crops,
args.max_scale_crops,
pil_blur=args.use_pil_blur,
)
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
@@ -266,7 +262,6 @@ def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
data_time = AverageMeter()
losses = AverageMeter()

softmax = nn.Softmax(dim=1).cuda()
model.train()
use_the_queue = False

@@ -295,7 +290,7 @@ def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
loss = 0
for i, crop_id in enumerate(args.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id: bs * (crop_id + 1)]
out = output[bs * crop_id: bs * (crop_id + 1)].detach()

# time to use the queue
if queue is not None:
@@ -308,20 +303,15 @@ def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
# fill the queue
queue[i, bs:] = queue[i, :-bs].clone()
queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]

# get assignments
q = out / args.epsilon
if args.improve_numerical_stability:
M = torch.max(q)
dist.all_reduce(M, op=dist.ReduceOp.MAX)
q -= M
q = torch.exp(q).t()
q = distributed_sinkhorn(q, args.sinkhorn_iterations)[-bs:]
q = distributed_sinkhorn(out)[-bs:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id):
p = softmax(output[bs * v: bs * (v + 1)] / args.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
x = output[bs * v: bs * (v + 1)] / args.temperature
subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1))
loss += subloss / (np.sum(args.nmb_crops) - 1)
loss /= len(args.crops_for_assign)
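To make the swapped prediction term explicit: the Sinkhorn assignment `q` computed from one crop serves as the target for the temperature-scaled softmax prediction of every other crop, through a cross-entropy term -sum(q * log p). A toy sketch with made-up tensors (not part of the training script):
```python
import torch
import torch.nn.functional as F

temperature = 0.1
q = torch.tensor([[0.7, 0.2, 0.1]])   # assignment of crop A from the Sinkhorn step (rows sum to 1)
scores = torch.randn(1, 3)            # prototype scores of another crop of the same image

log_p = F.log_softmax(scores / temperature, dim=1)
subloss = -torch.mean(torch.sum(q * log_p, dim=1))  # predict A's assignment from the other crop
```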

@@ -332,7 +322,7 @@ def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
scaled_loss.backward()
else:
loss.backward()
# cancel some gradients
# cancel gradients for the prototypes
if iteration < args.freeze_prototypes_niters:
for name, p in model.named_parameters():
if "prototypes" in name:
Expand Down Expand Up @@ -361,41 +351,30 @@ def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
return (epoch, losses.avg), queue


def distributed_sinkhorn(Q, nmb_iters):
with torch.no_grad():
Q = shoot_infs(Q)
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (args.world_size * Q.shape[1])
for it in range(nmb_iters):
u = torch.sum(Q, dim=1)
dist.all_reduce(u)
u = r / u
u = shoot_infs(u)
Q *= u.unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()


def shoot_infs(inp_tensor):
"""Replaces inf by maximum of tensor"""
mask_inf = torch.isinf(inp_tensor)
ind_inf = torch.nonzero(mask_inf)
if len(ind_inf) > 0:
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = 0
elif len(ind) == 1:
inp_tensor[ind[0]] = 0
m = torch.max(inp_tensor)
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = m
elif len(ind) == 1:
inp_tensor[ind[0]] = m
return inp_tensor
@torch.no_grad()
def distributed_sinkhorn(out):
Q = torch.exp(out / args.epsilon).t() # Q is K-by-B for consistency with notations from our paper
B = Q.shape[1] * args.world_size # number of samples to assign
K = Q.shape[0] # how many prototypes

    # make the matrix sum to 1
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

for it in range(args.sinkhorn_iterations):
# normalize each row: total weight per prototype must be 1/K
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
dist.all_reduce(sum_of_rows)
Q /= sum_of_rows
Q /= K

# normalize each column: total weight per sample must be 1/B
Q /= torch.sum(Q, dim=0, keepdim=True)
Q /= B

    Q *= B # the columns must sum to 1 so that Q is an assignment
return Q.t()
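For reference, the same normalization without the distributed all-reduce calls can be written as below; this is only an illustrative single-GPU sketch (epsilon and the iteration count are passed explicitly here instead of being read from `args`), not a drop-in replacement for the function above:
```python
import torch

@torch.no_grad()
def sinkhorn_single_gpu(out, epsilon=0.05, n_iters=3):
    Q = torch.exp(out / epsilon).t()          # K-by-B scores
    K, B = Q.shape

    Q /= Q.sum()                              # make the whole matrix sum to 1
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)       # each prototype (row) gets total weight 1/K
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)       # each sample (column) gets total weight 1/B
        Q /= B

    return (Q * B).t()                        # columns sum to 1: one soft assignment per sample
```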


if __name__ == "__main__":
1 change: 0 additions & 1 deletion scripts/deepclusterv2_400ep_2x224_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.08 \
--max_scale_crops 1. \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--feat_dim 128 \
--nmb_prototypes 3000 3000 3000 \
1 change: 0 additions & 1 deletion scripts/deepclusterv2_400ep_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.08 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--feat_dim 128 \
--nmb_prototypes 3000 3000 3000 \
1 change: 0 additions & 1 deletion scripts/deepclusterv2_800ep_pretrain.sh
@@ -31,7 +31,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--feat_dim 128 \
--nmb_prototypes 3000 3000 3000 \
1 change: 0 additions & 1 deletion scripts/swav_100ep_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_200ep_bs256_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_200ep_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_400ep_2x224_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.08 \
--max_scale_crops 1. \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_400ep_bs256_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_400ep_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_800ep_pretrain.sh
@@ -30,7 +30,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_RN50w2_400ep_pretrain.sh
@@ -31,7 +31,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \
1 change: 0 additions & 1 deletion scripts/swav_RN50w4_400ep_pretrain.sh
@@ -31,7 +31,6 @@ srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --labe
--min_scale_crops 0.08 0.08 \
--max_scale_crops 1. 0.14 \
--crops_for_assign 0 1 \
--use_pil_blur false \
--temperature 0.1 \
--epsilon 0.05 \
--sinkhorn_iterations 3 \