
[RFC] API For Common Layers In Torchvision #4333

@oke-aditya

Picks up from a comment in the discussion in #4293.

🚀 Feature

An API for layers commonly used in building models.

Motivation

A great deal of code is duplicated when building the basic blocks of large neural networks. Some of these blocks can be standardized and reused. They could also be offered to end users as an API, so that downstream libraries can build models easily.

Examples of such duplicated blocks are SqueezeExcitation, ConvBNRelu, etc.

Pitch

Create an API called torchvision.nn or torchvision.layers.
Our implementations need to be generic and not locked to particular activation functions, channel counts, etc.
These can simply be classes based on nn.Module.

An example

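from typing import Callable, Optional

from torch import Tensor, nn


class ConvNormAct(nn.Module):
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        # This could even be an nn.Sequential block instead.
        # "Same" padding for odd kernel sizes, taking dilation into account.
        padding = (kernel_size - 1) // 2 * dilation
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation, groups, bias=False)
        # norm_layer / activation_layer are factories (e.g. nn.BatchNorm2d, nn.ReLU), instantiated here.
        self.norm = norm_layer(out_planes) if norm_layer is not None else None
        self.activation = activation_layer() if activation_layer is not None else None

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x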

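Users could use this as:

import torch
from functools import partial
from torch import nn
from torchvision import layers  # proposed module, does not exist yet

# Pass layer factories (classes / partials), not instances.
l1 = layers.ConvNormAct(3, 10, norm_layer=nn.BatchNorm2d, activation_layer=partial(nn.ReLU, inplace=True))
dummy_input = torch.randn(1, 3, 224, 224)  # (N, C, H, W)
out = l1(dummy_input)  # shape: (1, 10, 224, 224)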

Layers can also be mixed into new custom models.

For example:

class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = layers.ConvNormAct(3, 10, activation_layer=nn.ReLU)
        self.l2 = layers.ConvNormAct(10, 3, activation_layer=nn.ReLU)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.l1(x)
        x = self.l2(x)
        return x

Points to Consider

We have torchvision.ops, so why layers?

Ops are transforms that handle pre-processing and post-processing of structures such as boxes, masks, anchors, etc. They are not used in "model building" but are optional steps for specific models.
Examples include NMS, IoU, RoI, etc.

One doesn't need ops for every model.
E.g. you don't need RoI Align for DETR, and you don't compute IoU for segmentation masks.

A separate API would draw a clear distinction between layers for building models and operators for tasks such as detection and segmentation.

Should torchvision.nn contain losses?

This is tricky, and for now I see no clear winner.
PyTorch does not separate the API for losses from the API for layers.
E.g. nn.Conv2d builds a convolutional layer, while nn.CrossEntropyLoss or nn.MSELoss builds a loss function.
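To make the point concrete: in PyTorch a layer and a loss are both ordinary `nn.Module` subclasses, constructed and called the same way, so the class hierarchy alone offers no separation. A minimal check (the shapes and class counts are arbitrary choices for illustration):

```python
import torch
from torch import nn

# A layer and a loss are both plain nn.Module subclasses.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
loss_fn = nn.CrossEntropyLoss()
print(isinstance(conv, nn.Module), isinstance(loss_fn, nn.Module))  # True True

# Both are invoked the same way: module(inputs...).
features = conv(torch.randn(1, 3, 32, 32))
logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = loss_fn(logits, targets)
print(features.shape, loss.dim())
```

Whether losses live next to layers in a torchvision namespace is therefore an organizational choice rather than something PyTorch's type system dictates.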

I'm not sure whether the module should be torchvision.layers or torchvision.nn (if implemented, of course).

Users don't need to worry about colliding namespaces; they can do:

from torch import nn
from torchvision import nn as tnn

Note that nn seems to be the convention adopted by torchtext.

Other points to consider.

  1. Portability:
    Currently most torchvision models are easy to copy-paste. E.g. we can copy the mobilenetv2.py file and edit it on the go to customize the model.
    Adding such an API would reduce internal code duplication, but these files would no longer be standalone, single-file model definitions.

  2. Layer customization:
    Layer customization has far too many options to consider.
    E.g. there are several possible implementations of ResNet's BasicBlock, or slight modifications of the inverted residual layer.
    No single implementation can suit everyone's needs; trying to cover them all would make the API significantly more complicated.

  3. TorchScript:
    Implementing the above API must not break TorchScript compatibility of any model.
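One possible way to keep point 2 manageable is to accept layer factories (classes or `functools.partial` objects) instead of adding a constructor flag per option, and point 3 can be checked mechanically with `torch.jit.script`. A minimal sketch, assuming a `ConvNormAct` along the lines of the example above (the name and signature are the proposal's, not an existing torchvision API); here `nn.Identity` stands in for "no layer" so that `forward` has no branches:

```python
from functools import partial
from typing import Callable, Optional

import torch
from torch import Tensor, nn


class ConvNormAct(nn.Module):
    """Hypothetical Conv2d -> norm -> activation block (sketch of the proposed API)."""

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
    ) -> None:
        super().__init__()
        padding = (kernel_size - 1) // 2
        # Keep the conv bias only when there is no norm layer to absorb it.
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, bias=norm_layer is None)
        # nn.Identity replaces None, so forward() is a straight chain of calls.
        self.norm = norm_layer(out_planes) if norm_layer is not None else nn.Identity()
        self.act = activation_layer() if activation_layer is not None else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        return self.act(self.norm(self.conv(x)))


# Customization through factories instead of extra constructor flags.
block = ConvNormAct(3, 16, norm_layer=partial(nn.BatchNorm2d, momentum=0.01), activation_layer=nn.Hardswish)
plain = ConvNormAct(3, 16, norm_layer=None, activation_layer=None)

# TorchScript check: scripting should succeed and match eager outputs.
block.eval()
scripted = torch.jit.script(block)
x = torch.randn(2, 3, 32, 32)
print(torch.allclose(block(x), scripted(x)))
```

Factories keep the constructor small: variations such as a different momentum or activation are expressed by the caller rather than by new parameters on the layer.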

Additional context

Some candidates for layers

  1. ConvNormAct
  2. LinearDropoutAct
  3. SqueezeExcitation
  4. StochasticDepth
  5. BasicBlock
  6. MLP: a simple multi-layer perceptron, often duplicated in downstream library codebases, e.g. DETR.
  7. FrozenBatchNorm2d: also used in DETR.
    Not sure why it is under ops; it doesn't fit there.
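To give a sense of the size of two of these candidates, here is a minimal sketch of a squeeze-and-excitation block and a DETR-style MLP. The class names, the `squeeze_channels` parameter, and the ReLU/sigmoid defaults are assumptions for illustration, not a settled torchvision API:

```python
from typing import Callable

import torch
from torch import Tensor, nn


class SqueezeExcitation(nn.Module):
    """Channel attention: global-pool, bottleneck, then rescale each channel."""

    def __init__(
        self,
        channels: int,
        squeeze_channels: int,
        activation: Callable[..., nn.Module] = nn.ReLU,
        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        # 1x1 convs act as the two fully connected layers of the bottleneck.
        self.fc1 = nn.Conv2d(channels, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def forward(self, x: Tensor) -> Tensor:
        scale = self.pool(x)  # (N, C, 1, 1)
        scale = self.activation(self.fc1(scale))
        scale = self.scale_activation(self.fc2(scale))
        return x * scale  # broadcast over H and W


class MLP(nn.Module):
    """Stack of Linear layers with ReLU between them (none after the last)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int) -> None:
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(i, o) for i, o in zip(dims[:-1], dims[1:]))

    def forward(self, x: Tensor) -> Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x


se = SqueezeExcitation(channels=32, squeeze_channels=8)
print(se(torch.randn(2, 32, 14, 14)).shape)  # torch.Size([2, 32, 14, 14])

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print(mlp(torch.randn(10, 256)).shape)  # torch.Size([10, 4])
```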

Also quantizable versions of these!

  1. QuantizableConvNormAct
  2. QuantizableLinearDropoutAct

Quantizable versions will allow users to fuse the models directly.
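As a rough illustration of what fusing means here: if a conv-norm-activation block is an `nn.Sequential` of `Conv2d`, `BatchNorm2d` and `ReLU`, PyTorch's eager-mode `fuse_modules` can merge the three into a single fused module ahead of quantization. The `QuantizableConvNormAct` name and its `fuse_model` method below are hypothetical:

```python
import copy

import torch
from torch import nn

try:
    from torch.ao.quantization import fuse_modules
except ImportError:  # older PyTorch releases
    from torch.quantization import fuse_modules


class QuantizableConvNormAct(nn.Sequential):
    """Hypothetical Conv2d -> BatchNorm2d -> ReLU block that knows how to fuse itself."""

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3) -> None:
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(),
        )

    def fuse_model(self) -> None:
        # Children of nn.Sequential are named "0", "1", "2".
        fuse_modules(self, [["0", "1", "2"]], inplace=True)


block = QuantizableConvNormAct(3, 8).eval()  # conv+BN fusion requires eval mode
reference = copy.deepcopy(block)  # unfused copy for comparison
block.fuse_model()

x = torch.randn(1, 3, 16, 16)
print(type(block[0]).__name__, type(block[1]).__name__)
print(torch.allclose(block(x), reference(x), atol=1e-5))
```

After fusion the batch norm is folded into the convolution's weights, and the now-redundant positions in the `Sequential` are replaced by `nn.Identity`, so the block's outputs stay the same.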

Additionally, I would recommend not rushing this feature: we could create torchvision.experimental.nn and start working things out there (or I could try it in a fork).
Linking plans: #3911 #4187

P.S. I'm a junior developer and my thoughts are probably a step too far, so please forgive me if I'm wrong.

cc @datumbox @pmeier @NicolasHug @fmassa
