diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 5dd1762727a..72a0a338852 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -37,29 +37,35 @@ def forward(self, x: Tensor) -> Tensor: return x +class Permute(nn.Module): + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.permute(x, self.dims) + + class CNBlock(nn.Module): def __init__( - self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module] + self, + dim, + layer_scale: float, + stochastic_depth_prob: float, + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.block = nn.Sequential( - ConvNormActivation( - dim, - dim, - kernel_size=7, - groups=dim, - norm_layer=norm_layer, - activation_layer=None, - bias=True, - ), - ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None), - ConvNormActivation( - 4 * dim, - dim, - kernel_size=1, - norm_layer=None, - activation_layer=None, - ), + nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True), + Permute([0, 2, 3, 1]), + norm_layer(dim), + nn.Linear(in_features=dim, out_features=4 * dim, bias=True), + nn.GELU(), + nn.Linear(in_features=4 * dim, out_features=dim, bias=True), + Permute([0, 3, 1, 2]), ) self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") @@ -142,7 +148,7 @@ def __init__( for _ in range(cnf.num_layers): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) + stage.append(block(cnf.input_channels, layer_scale, sd_prob)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) if cnf.out_channels is not None: @@ -213,7 +219,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", + url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ **_COMMON_META, @@ -227,7 +233,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", + url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=230), meta={ **_COMMON_META, @@ -241,7 +247,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", + url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, @@ -255,7 +261,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", + url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META,