Skip to content

Commit

Permalink
Optimize speed of CNBlock.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 31, 2022
1 parent 2bbb112 commit 290440b
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions torchvision/prototype/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,35 @@ def forward(self, x: Tensor) -> Tensor:
return x


class Permute(nn.Module):
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims

def forward(self, x):
return torch.permute(x, self.dims)


class CNBlock(nn.Module):
def __init__(
self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
self,
dim,
layer_scale: float,
stochastic_depth_prob: float,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)

self.block = nn.Sequential(
ConvNormActivation(
dim,
dim,
kernel_size=7,
groups=dim,
norm_layer=norm_layer,
activation_layer=None,
bias=True,
),
ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
ConvNormActivation(
4 * dim,
dim,
kernel_size=1,
norm_layer=None,
activation_layer=None,
),
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
Permute([0, 2, 3, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 3, 1, 2]),
)
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
Expand Down Expand Up @@ -142,7 +148,7 @@ def __init__(
for _ in range(cnf.num_layers):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
stage_block_id += 1
layers.append(nn.Sequential(*stage))
if cnf.out_channels is not None:
Expand Down Expand Up @@ -213,7 +219,7 @@ def _convnext(

class ConvNeXt_Tiny_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", # TODO: repackage
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
meta={
**_COMMON_META,
Expand All @@ -227,7 +233,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):

class ConvNeXt_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth",
url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", # TODO: repackage
transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
meta={
**_COMMON_META,
Expand All @@ -241,7 +247,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):

class ConvNeXt_Base_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth",
url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", # TODO: repackage
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
Expand All @@ -255,7 +261,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):

class ConvNeXt_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth",
url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", # TODO: repackage
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
Expand Down

0 comments on commit 290440b

Please sign in to comment.