
Commit 560d344

Author: ssjia
Update base for Update on "[ET-VK][AOT] Serialize constant tensors via NamedDataMap"
Summary: When exporting models to the Vulkan backend, save constant tensors in the NamedDataMap instead of in the constant data section of the delegate header.

## Motivation

Prevent screen blackout (Llama 3.2 1B) / device crash (Llama 3.2 3B) when running Llama 3.2 models on a Samsung Galaxy S24. This behaviour is related to high peak memory usage when loading the model. For more information, see the top diff/PR in the stack.

## Context

This change is based on the equivalent change D70315207 / #9153 in XNNPACK.

Test Plan:

## Memory Comparison with/without NamedDataMap

Measured VmRSS using

```
uint64_t getVmRssInKB() {
  std::ifstream statusFile("/proc/self/status");
  std::string l, num;
  while (std::getline(statusFile, l)) {
    if (l.substr(0, 5) == "VmRSS") {
      size_t pos = l.find_first_of("0123456789");
      num = l.substr(pos);
      break;
    }
  }
  uint64_t vmRssInKB = std::stoi(num);
  return vmRssInKB;
}
```

P1908019767 (Meta only). Excerpt:

```
Log 1                                              | Log 2
---------------------------------------------------|---------------------------------------------------
Memory usage before model compilation: 1115416 KB  | Memory usage before model compilation: 1919228 KB
Memory usage after graph building: 1924340 KB      | Memory usage after graph building: 1924256 KB
Memory usage after graph preparation: 1798968 KB   | Memory usage after graph preparation: 1782464 KB
Memory usage prepack start: 1798968 KB             | Memory usage prepack start: 1781968 KB
Memory usage after prepack operations: 1271924 KB  | Memory usage after prepack operations: 1653496 KB
```

Differential Revision: [D80460034](https://our.internmc.facebook.com/intern/diff/D80460034)

[ghstack-poisoned]
1 parent c5677ff commit 560d344
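For reference, a rough Python equivalent of the VmRSS probe from the test plan above (my sketch, Linux-only, not part of the commit; the helper name is illustrative):

```
def get_vm_rss_kb() -> int:
    """Return the current process's resident set size (VmRSS) in kB. Linux only."""
    with open("/proc/self/status") as status_file:
        for line in status_file:
            # The target line looks like: "VmRSS:   1115416 kB"
            if line.startswith("VmRSS"):
                return int(line.split()[1])
    raise RuntimeError("VmRSS not found in /proc/self/status")
```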


45 files changed: +732 −1440 lines

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosh_pass import DecomposeCoshPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
-from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 27 additions & 3 deletions
@@ -14,12 +14,36 @@
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch.library import impl, Library
+
+# Define lib with passthrough operators. The operators have no real meaning in edge IR
+# except for argument validation and a passthrough output. The operators will be used
+# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
+# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
+lib = Library("passthrough_to_tosa", "DEF")
+# For certain operators we need the data in a specific data format. Changing tosa_dim_order
+# is not sufficient as we also need to transpose the data.
+# By utilizing an edge IR passthrough operator we can keep the edge program in
+# channels-first/contiguous and get the desired behavior in the TOSA lowering.
+lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
+
+
+@impl(lib, "_transpose")
+def _transpose_impl(*args, **kwargs):
+    # Validate length of dim_order array
+    dim = args[1]
+    if len(dim) != 4 and len(dim) != 5:
+        raise ValueError(
+            f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}"
+        )
+    # Pass-through in edge-IR
+    return args[0]
 
 
 class AnnotateChannelsLastDimOrder(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
-    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
+    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
     when a transition between 3D and 4D/5D tensors happens.
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
@@ -95,7 +119,7 @@ def insert_input_transpose(node, input_node, graph_module):
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.TRANSPOSE.default,
+                torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     input_node,
                     list(
@@ -117,7 +141,7 @@ def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.TRANSPOSE.default,
+                torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     node,
                     list(
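For readers unfamiliar with the `torch.library` mechanism this diff restores, here is a hedged, self-contained sketch of how such a passthrough operator behaves. The `passthrough_demo` namespace is hypothetical (chosen so it does not collide with the real `passthrough_to_tosa` library):

```
import torch
from torch.library import impl, Library

# Hypothetical namespace standing in for "passthrough_to_tosa".
demo_lib = Library("passthrough_demo", "DEF")
demo_lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(demo_lib, "_transpose")
def _demo_transpose_impl(x, dim_order):
    # In edge IR the op only validates its arguments and passes the tensor
    # through; the real transpose is emitted later, during TOSA lowering.
    if len(dim_order) not in (4, 5):
        raise ValueError(f"Dim order length must be 4 or 5, got {len(dim_order)}")
    return x


t = torch.randn(1, 3, 8, 8)
out = torch.ops.passthrough_demo._transpose(t, [0, 2, 3, 1])
assert torch.equal(t, out)  # unchanged in eager/edge IR
```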

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 3 deletions
@@ -38,7 +38,6 @@
     DecomposeBatchNormNoStatsPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
-    DecomposeCumsumPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
@@ -149,7 +148,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
@@ -229,7 +227,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())

backends/arm/_passes/decompose_cumsum_pass.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def call(self, graph_module):
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target == exir_ops.backend.tosa.TABLE.default:
+            if node.target == torch.ops.tosa._table.default:
                 continue
 
             input_nodes = node.all_input_nodes

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 50 additions & 5 deletions
@@ -3,25 +3,70 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from copy import copy
 from typing import cast
 
+import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch import Tensor
 from torch.fx import GraphModule, Node
+from torch.library import custom_op, register_fake
+
+logger = logging.getLogger(__name__)
+
+
+@custom_op("tosa::_rescale", mutates_args=())  # type: ignore[misc]
+def rescale(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    logger.warning(
+        "Ran default implementation of tosa::_rescale."
+        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
+    )
+    # Clone is needed to not return a reference when rescaling to the same dtype.
+    # This is a necessary requirement for non-mutating custom ops.
+    return x.to(dtype=dtype).clone()
+
+
+@register_fake("tosa::_rescale")  # type: ignore[misc]
+def rescale_fake(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
+    Additionally validates TOSA constraints of a RESCALE op.
+    """
+    if dtype not in (torch.int32, torch.int8, torch.int16):
+        raise NotImplementedError(
+            f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
+        )
+    if dtype in (torch.int32, torch.int16) and out_zp != 0:
+        raise ValueError(
+            f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
+        )
+    if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
+        raise ValueError(
+            f"TOSA requires input_zp to be zero when the input dtype is {dtype}"
+        )
+    if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
+        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
+    if dtype == torch.int8 and not -128 <= out_zp <= 127:
+        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
+
+    return x.to(dtype=dtype).clone()
 
 
 class InsertRescalePass(ExportPass):
     """Finds patterns of dq -> q, and replaces them
-    with backend dialect tosa::RESCALE op.
+    with passthrough_to_tosa::rescales.
 
     Does not guarantee that the dtypes and zero points are valid
     in TOSA, that is the job of the quantization annotator that
     produced the dq and q nodes. The TOSA constraints are validated
-    in the fake implementation of.
+    in the fake implementation of passthrough_to_tosa:rescale.
     """
@@ -32,7 +77,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
         with graph_module.graph.inserting_before(node):
             rescale_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.RESCALE.default,
+                torch.ops.tosa._rescale.default,
                 (
                     node.all_input_nodes[0],
                     q_args.dtype,
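For illustration, a hedged usage sketch of the restored `tosa::_rescale` op (assumes importing the pass module above runs the registration; the import path and argument values are illustrative, not from the commit):

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumed import path for the file in this diff; importing it registers the op.
import executorch.backends.arm._passes.insert_rescales_pass  # noqa: F401

x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)

# Eager fallback: logs a warning and simply casts, since the real rescale
# only exists after TOSA lowering.
y = torch.ops.tosa._rescale(x, torch.int32, 0.5, 0, 0)
assert y.dtype == torch.int32

# During export, tensors are fake; the registered fake kernel both produces
# the int32 meta and enforces the TOSA constraints, e.g. rejecting a nonzero
# output zero point for an int32 output.
with FakeTensorMode():
    fx = torch.empty(2, 3, dtype=torch.int8)
    fy = torch.ops.tosa._rescale(fx, torch.int32, 0.5, 0, 0)
    assert fy.dtype == torch.int32
    try:
        torch.ops.tosa._rescale(fx, torch.int32, 0.5, 0, 7)
    except ValueError:
        pass  # "TOSA requires output_zp to be zero when the output dtype is torch.int32."
```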
