[ExecuTorch][Weight Sharing][XNNPACK] load named data map data for xnnpack #9294

Merged (5 commits) on Mar 15, 2025
13 changes: 13 additions & 0 deletions backends/xnnpack/CMakeLists.txt
@@ -37,6 +37,19 @@ option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
 # Keeping this OFF by default due to regressions in decode and model load with
 # kleidi kernels
 option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)
+
+# Turning this on caches weights between partitions and methods. If
+# weights are shared across methods/partitions, this can reduce load
+# time and memory usage.
+
+# Keeping this off maintains the existing behavior. Turning this on
+# serializes the execution and initialization of delegates; to be revisited.
+option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE
+       "Enable weights cache to cache and manage all packed weights" OFF)
+
+if(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE)
+  add_definitions(-DENABLE_XNNPACK_WEIGHTS_CACHE)
+endif()
 if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE)
   add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE)
 endif()
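For context, a minimal sketch of how a build might opt into the new option. Only EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE is defined by this diff; the other flag and the directory names are assumptions about the surrounding CMake setup.

    # Hypothetical configure step for an ExecuTorch checkout with the new
    # weight cache compiled in; paths and the backend toggle are illustrative.
    import subprocess

    subprocess.run(
        [
            "cmake",
            "-S", ".",
            "-B", "cmake-out",
            "-DEXECUTORCH_BUILD_XNNPACK=ON",  # assumed existing backend toggle
            "-DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=ON",  # option added in this diff
        ],
        check=True,
    )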
1 change: 1 addition & 0 deletions backends/xnnpack/_passes/TARGETS
@@ -19,5 +19,6 @@ python_library(
         "//executorch/exir/passes:const_prop_pass",
         "//executorch/exir/passes:memory_format_ops_pass",
         "//executorch/exir/program:program",
+        "//executorch/backends/transforms:utils",
     ],
 )
68 changes: 52 additions & 16 deletions backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
@@ -7,13 +7,22 @@
 import operator

 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)

 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

-from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.backends.xnnpack.utils.utils import (
+    get_param_tensor,
+    get_tensor_name,
+    is_param_node,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind

 from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):

     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        counter = 0
+        constant_placeholders_to_delete = set()
         for conv in graph.nodes:
             # We want to discover a chain of conv -> batch_norm.
             # Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             assert len(conv.args) == 9

             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
             assert conv_weight is not None

             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
+            conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])

             # Get the parameters from the batchnorm op
             assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_bias,
                 is_transpose,
             )
+            fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
+            if conv_bias_name == "":
+                fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
+                    ".", "_"
+                )
+            else:
+                fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")

             # Modify the graph by updating the weight and bias of conv op
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the conv op.
-            with graph.inserting_before(conv):
-                fused_weight_name = f"_fused_with_bn_weight_{counter}"
-                graph_module.register_parameter(fused_weight_name, fused_weight)
-                fused_weight_node = graph.get_attr(fused_weight_name)
-                fused_bias_name = f"_fused_with_bn_bias_{counter}"
-                graph_module.register_parameter(fused_bias_name, fused_bias)
-                fused_bias_node = graph.get_attr(fused_bias_name)
-
-                # Update the weight and bias of conv op
-                conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
-                conv_args[1] = fused_weight_node
-                conv_args[2] = fused_bias_node
-                conv.args = tuple(conv_args)
+            with graph.inserting_before(conv.args[1]):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=fused_weight_name,
+                    data=fused_weight,
+                )
+                if fused_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=fused_bias_name,
+                        data=fused_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )

             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
                 assert user.target == operator.getitem
                 user.replace_all_uses_with(conv)
                 graph.erase_node(user)

             graph.erase_node(bn)
+            constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])

-            counter += 1
+        if len(constant_placeholders_to_delete) > 0:
+            graph_module.graph.eliminate_dead_code()
+            for node in constant_placeholders_to_delete:
+                if (node is not None) and (len(node.users) == 0):
+                    delete_constant_placeholder(self.exported_program, node)

         graph_module.recompile()
         # To regenerate metadata and shape information, retrace the module
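To make the new deterministic naming rule concrete: a conv weight "conv1.weight" with bias "conv1.bias" yields "conv1_weight_fused_bn" and "conv1_bias_fused_bn", while a bias-less conv derives its new bias name from the weight. A standalone, runnable sketch of the rule used above (the helper function name is ours, not part of the PR):

    def fused_param_names(conv_weight_name: str, conv_bias_name: str) -> tuple[str, str]:
        # Mirror of the naming rule in the pass: dots become underscores so the
        # result is a legal name segment, and a missing bias borrows the weight's name.
        fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
        if conv_bias_name == "":
            fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(".", "_")
        else:
            fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
        return fused_weight_name, fused_bias_name

    assert fused_param_names("conv1.weight", "conv1.bias") == (
        "conv1_weight_fused_bn",
        "conv1_bias_fused_bn",
    )
    assert fused_param_names("conv1.weight", "") == (
        "conv1_weight_fused_bn",
        "conv1_weight_bias_fused_bn",
    )

Stable, name-derived identifiers matter here because the weights are later stored in the named data map keyed by tensor name, replacing the old per-pass counter which was not tied to the exported program's state.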
26 changes: 17 additions & 9 deletions backends/xnnpack/operators/node_visitor.py
@@ -34,20 +34,23 @@
     check_or_raise,
     get_input_node,
     get_param_tensor,
+    get_tensor_name,
     is_param_node,
     PERM_NCHW_TO_NHWC,
 )

-from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.backends.xnnpack.utils.xnnpack_constants import (
+    UINT64_MAX,
+    XNN_INVALID_VALUE_ID,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
 from torch.export import ExportedProgram

 XNN_TYPE_MAP = {
     torch.float32: XNNDatatype.xnn_datatype_fp32,
 }

 from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
-    _aligned_size,
-    _pad_to,
     CONSTANT_TENSOR_ALIGNMENT,
 )

@@ -86,11 +89,11 @@ def __init__(
         self,
         exported_program: ExportedProgram,
         external_ids: Dict,
-        constant_data_bytes: bytearray,
+        named_data_store: NamedDataStore,
     ) -> None:
         self._external_ids = external_ids or {}
         self._exported_program = exported_program or None
-        self._constant_data_bytes = constant_data_bytes
+        self._named_data_store = named_data_store

     @property
     def external_ids(self) -> Dict:
@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
             ctypes.POINTER(array_type),
         ).contents

-        offset = len(self._constant_data_bytes)
+        named_key = get_tensor_name(self.exported_program, get_attr_node)
+        if named_key == "":
+            raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+
         size = const_val.untyped_storage().nbytes()
-        xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
-        self._constant_data_bytes.extend(
-            _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
+        xnn_graph.constant_data.append(
+            ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
+        self._named_data_store.add_named_data(
+            named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
+        )

         return buffer_idx
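The net effect of this hunk: constant tensor payloads are no longer appended to one contiguous constant_data_bytes buffer addressed by byte offset. Each payload is instead registered in the NamedDataStore under its tensor name, and the flatbuffer entry carries the UINT64_MAX sentinel offset plus the named_key the runtime uses to look the data up, which is what enables sharing one weight blob across methods and partitions. A minimal sketch of the store API as used here; only add_named_data with these arguments appears in the diff, while the standalone construction, the alignment value, and the output accessor are our assumptions:

    from executorch.exir._serialize._named_data_store import NamedDataStore

    # Assumed value; the real constant comes from xnnpack_graph_serialize.
    CONSTANT_TENSOR_ALIGNMENT = 16

    store = NamedDataStore()
    packed_weight = bytes(64)  # stand-in for a packed constant tensor

    # Register the payload under the tensor's fully qualified name, the same
    # named_key recorded in ConstantDataOffset above.
    store.add_named_data(
        "conv1_weight_fused_bn",
        packed_weight,
        alignment=CONSTANT_TENSOR_ALIGNMENT,
    )

    # Assumed accessor: the serialized program embeds this output so the
    # runtime can resolve named_key -> bytes at load time instead of
    # offset -> bytes.
    output = store.get_named_data_store_output()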