
Commit

Merge branch 'main' into models/convnext_variants
datumbox authored Jan 31, 2022
2 parents 1ab9030 + ac1f0ff · commit f803797
Showing 6 changed files with 128 additions and 38 deletions.
Binary file added test/expect/ModelTester.test_vitc_b_16_expect.pkl
30 changes: 30 additions & 0 deletions test/test_models.py
@@ -8,6 +8,7 @@
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
+from typing import Any
 
 import pytest
 import torch
@@ -514,6 +515,35 @@ def test_generalizedrcnn_transform_repr():
     assert t.__repr__() == expected_string
 
 
+test_vit_conv_stem_configs = [
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
+]
+
+
+def vitc_b_16(**kwargs: Any):
+    return models.VisionTransformer(
+        image_size=224,
+        patch_size=16,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=768,
+        mlp_dim=3072,
+        conv_stem_configs=test_vit_conv_stem_configs,
+        **kwargs,
+    )
+
+
+@pytest.mark.parametrize("model_fn", [vitc_b_16])
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_vitc_models(model_fn, dev):
+    test_classification_model(model_fn, dev)
+
+
 @pytest.mark.parametrize("model_fn", get_models_from_module(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):
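For context, a minimal sketch of how the vitc_b_16 helper above can be exercised outside the test suite (illustrative only, not part of this commit's diff; num_classes is left at the VisionTransformer default of 1000):

import torch

model = vitc_b_16()  # ViT-B/16 with the six-layer conv stem defined above
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # image_size=224 as configured
print(logits.shape)  # torch.Size([1, 1000])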
37 changes: 25 additions & 12 deletions test/test_utils.py
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
         utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
 
 
-def test_flow_to_image():
+@pytest.mark.parametrize("batch", (True, False))
+def test_flow_to_image(batch):
     h, w = 100, 100
     flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     flow = torch.stack(flow[::-1], dim=0).float()
     flow[0] -= h / 2
     flow[1] -= w / 2
 
+    if batch:
+        flow = torch.stack([flow, flow])
+
     img = utils.flow_to_image(flow)
+    assert img.shape == ((2, 3, h, w) if batch else (3, h, w))
 
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
     expected_img = torch.load(path, map_location="cpu")
-    assert_equal(expected_img, img)
 
+    if batch:
+        expected_img = torch.stack([expected_img, expected_img])
+
+    assert_equal(expected_img, img)
 
 
-def test_flow_to_image_errors():
-    wrong_flow1 = torch.full((3, 10, 10), 0, dtype=torch.float)
-    wrong_flow2 = torch.full((2, 10), 0, dtype=torch.float)
-    wrong_flow3 = torch.full((2, 10, 30), 0, dtype=torch.int)
-
-    with pytest.raises(ValueError, match="Input flow should have shape"):
-        utils.flow_to_image(flow=wrong_flow1)
-    with pytest.raises(ValueError, match="Input flow should have shape"):
-        utils.flow_to_image(flow=wrong_flow2)
-    with pytest.raises(ValueError, match="Flow should be of dtype torch.float"):
-        utils.flow_to_image(flow=wrong_flow3)
+@pytest.mark.parametrize(
+    "input_flow, match",
+    (
+        (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
+    ),
+)
+def test_flow_to_image_errors(input_flow, match):
+    with pytest.raises(ValueError, match=match):
+        utils.flow_to_image(flow=input_flow)
 
 
 if __name__ == "__main__":
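As the parametrized test above exercises, utils.flow_to_image accepts either a single flow or a batch after this change. A minimal sketch (illustrative shapes, not part of this commit's diff):

import torch
from torchvision import utils

flow = torch.randn(2, 64, 64)      # single flow field, shape (2, H, W)
img = utils.flow_to_image(flow)    # -> (3, 64, 64), dtype torch.uint8

flows = torch.randn(4, 2, 64, 64)  # batch of flows, shape (N, 2, H, W)
imgs = utils.flow_to_image(flows)  # -> (4, 3, 64, 64), dtype torch.uint8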
6 changes: 3 additions & 3 deletions torchvision/datasets/hmdb51.py
@@ -11,7 +11,7 @@

 class HMDB51(VisionDataset):
     """
-    `HMDB51 <http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
     dataset.
 
     HMDB51 is an action recognition video dataset.
@@ -47,9 +47,9 @@ class HMDB51(VisionDataset):
             - label (int): class of the video clip
     """
 
-    data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
     splits = {
-        "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
         "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
     }
     TRAIN_TAG = 1
56 changes: 47 additions & 9 deletions torchvision/models/vision_transformer.py
@@ -1,12 +1,13 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, NamedTuple, Optional
 
 import torch
 import torch.nn as nn
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
 from ..utils import _log_api_usage_once
 
 __all__ = [
@@ -25,6 +26,14 @@
 }
 
 
+class ConvStemConfig(NamedTuple):
+    out_channels: int
+    kernel_size: int
+    stride: int
+    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
+    activation_layer: Callable[..., nn.Module] = nn.ReLU
+
+
 class MLPBlock(nn.Sequential):
     """Transformer MLP block."""
 
@@ -134,6 +143,7 @@ def __init__(
         num_classes: int = 1000,
         representation_size: Optional[int] = None,
         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -148,11 +158,31 @@ def __init__(
         self.representation_size = representation_size
         self.norm_layer = norm_layer
 
-        input_channels = 3
-
-        # The conv_proj is a more efficient version of reshaping, permuting
-        # and projecting the input
-        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+        if conv_stem_configs is not None:
+            # As per https://arxiv.org/abs/2106.14881
+            seq_proj = nn.Sequential()
+            prev_channels = 3
+            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
+                seq_proj.add_module(
+                    f"conv_bn_relu_{i}",
+                    ConvNormActivation(
+                        in_channels=prev_channels,
+                        out_channels=conv_stem_layer_config.out_channels,
+                        kernel_size=conv_stem_layer_config.kernel_size,
+                        stride=conv_stem_layer_config.stride,
+                        norm_layer=conv_stem_layer_config.norm_layer,
+                        activation_layer=conv_stem_layer_config.activation_layer,
+                    ),
+                )
+                prev_channels = conv_stem_layer_config.out_channels
+            seq_proj.add_module(
+                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
+            )
+            self.conv_proj: nn.Module = seq_proj
+        else:
+            self.conv_proj = nn.Conv2d(
+                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
+            )
 
         seq_length = (image_size // patch_size) ** 2
 
@@ -184,9 +214,17 @@ def __init__(
         self._init_weights()
 
     def _init_weights(self):
-        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
-        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-        nn.init.zeros_(self.conv_proj.bias)
+        if isinstance(self.conv_proj, nn.Conv2d):
+            # Init the patchify stem
+            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.conv_proj.bias)
+        else:
+            # Init the last 1x1 conv of the conv stem
+            nn.init.normal_(
+                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
+            )
+            nn.init.zeros_(self.conv_proj.conv_last.bias)
 
         if hasattr(self.heads, "pre_logits"):
             fan_in = self.heads.pre_logits.in_features
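One dimensional constraint is worth spelling out: seq_length is still computed as (image_size // patch_size) ** 2, so the conv stem's cumulative stride must equal patch_size. The test configuration uses strides 2, 2, 1, 2, 1, 2, whose product is 16. A standalone sketch of the stem built by the new branch above (not part of this commit's diff; assumes ConvNormActivation's default kernel_size // 2 padding, channel sizes taken from test_vit_conv_stem_configs):

import torch
import torch.nn as nn
from torchvision.ops.misc import ConvNormActivation

stem = nn.Sequential(
    ConvNormActivation(3, 64, kernel_size=3, stride=2),      # 224 -> 112
    ConvNormActivation(64, 128, kernel_size=3, stride=2),    # 112 -> 56
    ConvNormActivation(128, 128, kernel_size=3, stride=1),
    ConvNormActivation(128, 256, kernel_size=3, stride=2),   # 56 -> 28
    ConvNormActivation(256, 256, kernel_size=3, stride=1),
    ConvNormActivation(256, 512, kernel_size=3, stride=2),   # 28 -> 14
    nn.Conv2d(512, 768, kernel_size=1),  # "conv_last": project to hidden_dim
)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 768, 14, 14]); 14 * 14 == (224 // 16) ** 2 tokens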
37 changes: 23 additions & 14 deletions torchvision/utils.py
@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
     Converts a flow to an RGB image.
 
     Args:
-        flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
+        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
 
     Returns:
-        img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
+        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
+            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
     """
 
     if flow.dtype != torch.float:
         raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
 
-    if flow.ndim != 3 or flow.size(0) != 2:
-        raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.")
+    orig_shape = flow.shape
+    if flow.ndim == 3:
+        flow = flow[None]  # Add batch dim
 
-    max_norm = torch.sum(flow ** 2, dim=0).sqrt().max()
+    if flow.ndim != 4 or flow.shape[1] != 2:
+        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
+
+    max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
     epsilon = torch.finfo((flow).dtype).eps
     normalized_flow = flow / (max_norm + epsilon)
-    return _normalized_flow_to_image(normalized_flow)
+    img = _normalized_flow_to_image(normalized_flow)
+
+    if len(orig_shape) == 3:
+        img = img[0]  # Remove batch dim
+    return img
 
 
 @torch.no_grad()
 def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
 
     """
-    Converts a normalized flow to an RGB image.
+    Converts a batch of normalized flow to an RGB image.
 
     Args:
-        normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
+        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
 
     Returns:
-        img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
+        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
     """
 
-    _, H, W = normalized_flow.shape
-    flow_image = torch.zeros((3, H, W), dtype=torch.uint8)
+    N, _, H, W = normalized_flow.shape
+    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
     colorwheel = _make_colorwheel()  # shape [55x3]
     num_cols = colorwheel.shape[0]
-    norm = torch.sum(normalized_flow ** 2, dim=0).sqrt()
-    a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi
+    norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
+    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
     fk = (a + 1) / 2 * (num_cols - 1)
     k0 = torch.floor(fk).to(torch.long)
     k1 = k0 + 1
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
         col1 = tmp[k1] / 255.0
         col = (1 - f) * col0 + f * col1
         col = 1 - norm * (1 - col)
-        flow_image[c, :, :] = torch.floor(255 * col)
+        flow_image[:, c, :, :] = torch.floor(255 * col)
     return flow_image


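Because max_norm above is taken over the whole input before colorization, a single flow and the same flow wrapped as a batch of one produce identical images. A quick sketch of that invariant (illustrative, not part of this commit's diff):

import torch
from torchvision import utils

flow = torch.randn(2, 32, 32)
single = utils.flow_to_image(flow)         # internally promoted to (1, 2, 32, 32)
batched = utils.flow_to_image(flow[None])  # explicit batch of one
assert torch.equal(single, batched[0])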
