diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index f7dabe1ec1..44c9885fa0 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -10,7 +10,7 @@ from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library -from .utils import get_conv1d_output_size +from .utils import get_conv1d_output_size, get_conv2d_output_size lib = Library("cadence", "DEF") @@ -122,7 +122,7 @@ def quantized_conv_meta( out_multiplier: torch.Tensor, out_shift: torch.Tensor, channel_last: bool = False, -): +) -> torch.Tensor: out_channels, _in_channels, *kernel_size = weight.shape in_size = input.shape # Assert that the input tensor has at least 3 dimensions, and at most 6 @@ -130,8 +130,14 @@ def quantized_conv_meta( assert len(in_size) < 6 # Compute the output tensor size - output_size = get_conv1d_output_size( - in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0] + output_size = ( + get_conv1d_output_size( + in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0] + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, channel_last + ) ) return input.new_empty(output_size, dtype=input.dtype) diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index 0ef5c50e5b..90ba68e538 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -6,7 +6,7 @@ import logging import operator -from typing import Dict +from typing import Dict, List, Tuple import torch from executorch.exir import memory @@ -49,6 +49,33 @@ def get_conv1d_output_size( return torch.Size((in_size[0], out_channels, lout)) +# Get the output size of a 2D convolution given the input size and parameters +def get_conv2d_output_size( + in_size: torch.Size, + out_channels: int, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + kernel_size: List[int], + channel_last: bool, +) -> torch.Size: + assert len(in_size) == 4 + if channel_last: + N, H, W, C = in_size + else: + N, C, H, W = in_size + + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[ + 0 + ] + 1 + wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[ + 1 + ] + 1 + + return torch.Size((in_size[0], out_channels, hout, wout)) + + # Return the overload packet for the edge op def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket: edge_op_namespace, edge_op_name = ( diff --git a/examples/cadence/models/vision_transformer.py b/examples/cadence/models/vision_transformer.py new file mode 100644 index 0000000000..79c9e3f196 --- /dev/null +++ b/examples/cadence/models/vision_transformer.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch +import torchvision + +from executorch.backends.cadence.aot.export_example import export_model + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + + model = torchvision.models.vit_b_16() + example_inputs = (torch.randn(1, 3, 224, 224),) + + export_model(model, example_inputs)