Skip to content

Commit

Permalink
Update OSS repo (#2033)
Browse files Browse the repository at this point in the history
Summary:

Update the OSS Xtensa repo with more up to date compiler and quantizer things. Introduce a test folder and a conv1d test.

Reviewed By: cccclai

Differential Revision: D54034581
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Feb 21, 2024
1 parent 20714e7 commit d3b8584
Show file tree
Hide file tree
Showing 9 changed files with 653 additions and 102 deletions.
25 changes: 3 additions & 22 deletions examples/xtensa/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,7 @@
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
in_features = 32
out_features = 16
bias = True
shape = [64, in_features]

class QuantizedLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool):
super().__init__()
self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)

def forward(self, x: torch.Tensor):
output_linear_out = self.output_linear(x)
return output_linear_out

model = QuantizedLinear(in_features, out_features, bias)
model.eval()

example_inputs = (torch.ones(shape),)

def export_xtensa_model(model, example_inputs):
# Quantizer
quantizer = XtensaQuantizer()

Expand Down Expand Up @@ -77,14 +58,14 @@ def forward(self, x: torch.Tensor):
export_to_edge(
converted_model_exp,
example_inputs,
EdgeCompileConfig(
edge_compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
.transform(
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()]
)
.to_executorch(config=ExecutorchBackendConfig(extract_constant_segment=False))
.to_executorch(config=ExecutorchBackendConfig())
)

logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
Expand Down
58 changes: 49 additions & 9 deletions examples/xtensa/aot/meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from .utils import get_conv1d_output_size
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

Expand All @@ -25,10 +28,17 @@
)

lib.define(
"quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
"quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
)
lib.define(
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
"quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("xtensa", "IMPL", "Meta")
Expand Down Expand Up @@ -58,17 +68,15 @@ def dequantize_per_tensor_meta(
return input.new_empty(input.size(), dtype=torch.float)


@impl(m, "quantized_linear_pt2")
def quantized_linear_pt2_meta(
@impl(m, "quantized_linear")
def quantized_linear_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_scale: float,
in_zero_point: int,
weight_scale: float,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
weight_zero_point: torch.Tensor,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_zero_point: int,
):
# src comes in shape [leading_dims, in_dim]
Expand All @@ -79,3 +87,35 @@ def quantized_linear_pt2_meta(
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=torch.uint8)


@impl(m, "quantized_conv")
def quantized_conv_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
channel_last: bool = False,
):
out_channels, _in_channels, *kernel_size = weight.shape
in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = get_conv1d_output_size(
in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
)

return input.new_empty(output_size, dtype=input.dtype)
Loading

0 comments on commit d3b8584

Please sign in to comment.