Add vision transformer (#4077)
Summary:
Pull Request resolved: #4077

As titled. It will be useful for a few Cadence teams to be able to at least look at the AoT graph.

Reviewed By: dulinriley

Differential Revision: D59097944

fbshipit-source-id: ed94d5b1adfb9d26845062449127cee946e29092
mcremon-meta authored and facebook-github-bot committed Jun 28, 2024
1 parent 748a4f8 commit 3eec95a
Showing 3 changed files with 66 additions and 5 deletions.
14 changes: 10 additions & 4 deletions backends/cadence/aot/ops_registrations.py
@@ -10,7 +10,7 @@
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library

-from .utils import get_conv1d_output_size
+from .utils import get_conv1d_output_size, get_conv2d_output_size

 lib = Library("cadence", "DEF")

@@ -122,16 +122,22 @@ def quantized_conv_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
     channel_last: bool = False,
-):
+) -> torch.Tensor:
     out_channels, _in_channels, *kernel_size = weight.shape
     in_size = input.shape
     # Assert that the input tensor has at least 3 and at most 5 dimensions
     assert len(in_size) > 2
     assert len(in_size) < 6

     # Compute the output tensor size
-    output_size = get_conv1d_output_size(
-        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    output_size = (
+        get_conv1d_output_size(
+            in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
     )

     return input.new_empty(output_size, dtype=input.dtype)
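With this change, quantized_conv_meta dispatches on input rank: a 3D (N, C, L) input takes the conv1d path (now reading index 1 of the stride/padding/dilation pairs, apparently the width component of (H, W) pairs), while a 4D input takes the new conv2d path. A minimal sketch of that dispatch with hypothetical shapes and parameters (only the helper signatures come from this diff):

import torch

from executorch.backends.cadence.aot.utils import (
    get_conv1d_output_size,
    get_conv2d_output_size,
)

# Hypothetical (H, W) parameter pairs; the 1D path reads the width slot (index 1).
stride, padding, dilation = (1, 2), (0, 1), (1, 1)

# Rank-3 input (N, C, L) -> 1D helper
conv1d_out = get_conv1d_output_size(
    torch.Size((1, 16, 100)), 32, stride[1], padding[1], dilation[1], 5
)

# Rank-4 input (N, C, H, W) -> 2D helper (channel_last=False)
conv2d_out = get_conv2d_output_size(
    torch.Size((1, 3, 224, 224)), 8, (2, 2), (1, 1), (1, 1), [3, 3], False
)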
29 changes: 28 additions & 1 deletion backends/cadence/aot/utils.py
@@ -6,7 +6,7 @@

 import logging
 import operator
-from typing import Dict
+from typing import Dict, List, Tuple

 import torch
 from executorch.exir import memory
@@ -49,6 +49,33 @@ def get_conv1d_output_size(
     return torch.Size((in_size[0], out_channels, lout))


+# Get the output size of a 2D convolution given the input size and parameters
+def get_conv2d_output_size(
+    in_size: torch.Size,
+    out_channels: int,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    kernel_size: List[int],
+    channel_last: bool,
+) -> torch.Size:
+    assert len(in_size) == 4
+    if channel_last:
+        N, H, W, C = in_size
+    else:
+        N, C, H, W = in_size
+
+    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[
+        0
+    ] + 1
+    wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
+        1
+    ] + 1
+
+    return torch.Size((in_size[0], out_channels, hout, wout))
+
+
 # Return the overload packet for the edge op
 def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
     edge_op_namespace, edge_op_name = (
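The hout/wout arithmetic above is the standard floor formula from the torch.nn.Conv2d docs. A quick cross-check of the same formula against an actual torch.nn.Conv2d (the concrete shapes and parameters here are illustrative, not from this diff):

import torch

# Illustrative conv2d configuration
N, C, H, W = 1, 3, 224, 224
out_channels = 8
kernel_size, stride, padding, dilation = (3, 3), (2, 2), (1, 1), (1, 1)

# Same floor formula as get_conv2d_output_size above
hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

# torch.nn.Conv2d should agree on the output shape: (1, 8, 112, 112)
conv = torch.nn.Conv2d(C, out_channels, kernel_size, stride, padding, dilation)
assert conv(torch.randn(N, C, H, W)).shape == (N, out_channels, hout, wout)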
28 changes: 28 additions & 0 deletions examples/cadence/models/vision_transformer.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting simple models to flatbuffer
+
+import logging
+
+from executorch.backends.cadence.aot.ops_registrations import * # noqa
+
+import torch
+import torchvision
+
+from executorch.backends.cadence.aot.export_example import export_model
+
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+if __name__ == "__main__":
+
+    model = torchvision.models.vit_b_16()
+    example_inputs = (torch.randn(1, 3, 224, 224),)
+
+    export_model(model, example_inputs)
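export_model here is the Cadence examples helper (executorch.backends.cadence.aot.export_example). For orientation, a rough sketch of what a generic ExecuTorch AoT export of this model might look like with core APIs; this is an assumption about the flow, not the helper's actual body:

import torch
import torchvision
from executorch.exir import to_edge

model = torchvision.models.vit_b_16().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture the AoT graph, lower to the Edge dialect, then serialize to flatbuffer.
exported = torch.export.export(model, example_inputs)
edge = to_edge(exported)
flatbuffer = edge.to_executorch().buffer  # bytes ready to write to a .pte file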
