Skip to content

Commit

Permalink
Update OSS repo (#2033)
Browse files Browse the repository at this point in the history
Summary:

Update the OSS Xtensa repo with more up to date compiler and quantizer things. Introduce a test folder and a conv1d test.

Reviewed By: cccclai

Differential Revision: D54034581
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Mar 29, 2024
1 parent f799c0e commit 9653c8c
Show file tree
Hide file tree
Showing 17 changed files with 1,385 additions and 140 deletions.
13 changes: 8 additions & 5 deletions docs/source/build-run-xtensa.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ examples/xtensa/
├── aot
├── kernels
├── ops
├── tests
├── third-party
└── utils
```

***AoT (Ahead-of-Time) Components***:

The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) defines a model and some example inputs (set to a vector of ones), and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.

***Operators***:

Expand All @@ -99,13 +100,15 @@ python3 -m examples.portable.scripts.export --model_name="add"

***Quantized Linear***:

The second, more complex model is a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py#L88). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
The other, more complex model are custom operators, including:
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.

The generated file is called `XtensaDemoModel.pte`.
In both cases the generated file is called `XtensaDemoModel.pte`.

```bash
cd executorch
python3 -m examples.xtensa.aot.export_example
python3 -m examples.xtensa.tests.quantized_<linear,conv1d>_example
```

### Runtime
Expand Down Expand Up @@ -196,6 +199,6 @@ First 20 elements of output 0

In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip.

The model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model in [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).

Other models can be created following the same structure, always assuming that operators and kernels are available.
70 changes: 70 additions & 0 deletions examples/xtensa/aot/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable

import torch

from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge

from torch.export import export
from torch.export.exported_program import ExportedProgram


def export_program(
model: Callable,
inputs: Any,
pt2_quant: bool = False,
) -> ExportedProgram:
# we don't support training mode. Make it eval
if hasattr(model, "eval"):
if pt2_quant:
# pyre-fixme[6]: Incompatible parameter type.
torch.ao.quantization.move_exported_model_to_eval(model)
else:
# pyre-fixme[16]: Anonymous callable has no attribute `eval`.
model.eval()

# if it's already an ExportedProgram, just return it
if isinstance(model, ExportedProgram):
return model

assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

# Prevent mkldnn decompositions
torch._C._set_mkldnn_enabled(False)

# else: capture the model and return it.
return export(model, inputs)


# Export the model and lower it it edge IR.
def export_to_edge(
model: Callable,
inputs: Any,
pt2_quant: bool = False,
dump_graphs: bool = False,
) -> EdgeProgramManager:
# Export the model into an ExportedProgram.
expo_program = export_program(model, inputs, pt2_quant)

if dump_graphs:
logging.info(
f"Exported graph:\n{expo_program.graph_module.graph}"
)

# Call to_edge to convert the graph to edge IR.
edge_prog_manager = to_edge(
expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)

if dump_graphs:
logging.info(
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
)

return edge_prog_manager
57 changes: 13 additions & 44 deletions examples/xtensa/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,28 @@

from .meta_registrations import * # noqa

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir import ExecutorchBackendConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from ...portable.utils import export_to_edge, save_pte_program
from ...portable.utils import save_pte_program

from .compiler import export_to_edge
from .quantizer import (
QuantFusion,
ReplacePT2DequantWithXtensaDequant,
ReplacePT2QuantWithXtensaQuant,
XtensaQuantizer,
XtensaBaseQuantizer,
)


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
in_features = 32
out_features = 16
bias = True
shape = [64, in_features]

class QuantizedLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool):
super().__init__()
self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)

def forward(self, x: torch.Tensor):
output_linear_out = self.output_linear(x)
return output_linear_out

model = QuantizedLinear(in_features, out_features, bias)
model.eval()

example_inputs = (torch.ones(shape),)

def export_xtensa_model(model, example_inputs):
# Quantizer
quantizer = XtensaQuantizer()
quantizer = XtensaBaseQuantizer()

# Export
model_exp = capture_pre_autograd_graph(model, example_inputs)
Expand All @@ -66,29 +47,17 @@ def forward(self, x: torch.Tensor):
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

# pre-autograd export. eventually this will become torch.export
converted_model_exp = capture_pre_autograd_graph(converted_model, example_inputs)
# Get edge program (note: the name will change to export_to_xtensa in future PRs)
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)

converted_model_exp = torch.ao.quantization.move_exported_model_to_eval(
converted_model_exp
# Run a couple required passes for quant/dequant ops
xtensa_prog_manager = edge_prog_manager.transform(
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()]
)

exec_prog = (
export_to_edge(
converted_model_exp,
example_inputs,
edge_compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
.transform(
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
check_ir_validity=False,
)
.to_executorch()
)
exec_prog = xtensa_prog_manager.to_executorch(config=ExecutorchBackendConfig())

logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
logging.info(f"Final exported graph module:\n{exec_prog.exported_program().graph_module}")

# Save the program as XtensaDemoModel.pte
save_pte_program(exec_prog, "XtensaDemoModel")
99 changes: 90 additions & 9 deletions examples/xtensa/aot/meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

from .utils import get_conv1d_output_size

lib = Library("xtensa", "DEF")

lib.define(
Expand All @@ -25,10 +29,33 @@
)

lib.define(
"quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
"quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)

lib.define(
"quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
"quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)"
)

lib.define(
"quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
"quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
"quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("xtensa", "IMPL", "Meta")
Expand Down Expand Up @@ -58,18 +85,17 @@ def dequantize_per_tensor_meta(
return input.new_empty(input.size(), dtype=torch.float)


@impl(m, "quantized_linear_pt2")
def quantized_linear_pt2_meta(
@impl(m, "quantized_linear")
def quantized_linear_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_scale: float,
in_zero_point: int,
weight_scale: float,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
weight_zero_point: torch.Tensor,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_zero_point: int,
offset: Optional[torch.Tensor],
):
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
Expand All @@ -79,3 +105,58 @@ def quantized_linear_pt2_meta(
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=torch.uint8)


@impl(m, "quantized_conv")
def quantized_conv_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
channel_last: bool = False,
):
out_channels, _in_channels, *kernel_size = weight.shape
in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = get_conv1d_output_size(
in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
)

return input.new_empty(output_size, dtype=input.dtype)


@impl(m, "quantized_layer_norm")
def quantized_layer_norm_meta(
input: torch.Tensor,
X_scale: torch.Tensor,
X_zero_point: torch.Tensor,
normalized_shape: int,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
output_scale: float,
output_zero_point: int,
):
return input.new_empty(input.size(), dtype=torch.uint8)


@impl(m, "quantized_relu")
def quantized_relu_meta(
X: torch.Tensor,
X_zero_point: torch.Tensor,
):
return X.new_empty(X.size(), dtype=torch.uint8)
Loading

0 comments on commit 9653c8c

Please sign in to comment.