Merged

21 commits
877119f
Register clone_dim_order op; add test for op replacement
keyprocedure Jul 29, 2025
f75845d
Rename clone_dim_order op registration test
keyprocedure Jul 29, 2025
83d8c75
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Jul 29, 2025
cff39c9
Add graph level and end to end tests for _clone_dim_order op
keyprocedure Aug 3, 2025
1fe461f
Remove _clone_dim_order op registration (moved to PR #12974)
keyprocedure Aug 3, 2025
95db027
Register _clone_dim_order op
keyprocedure Aug 6, 2025
63d45e7
Merge branch 'main' into add-dim-order-clone-aot
keyprocedure Aug 11, 2025
d9a181c
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 12, 2025
c7caa27
Register _clone_dim_order as no-op in CoreML
keyprocedure Aug 13, 2025
e54605f
Remove redundant _clone_dim_order graph check
keyprocedure Aug 13, 2025
246bc44
Add _clone_dim_order to RemoveClonePass and update op name in tests
keyprocedure Aug 13, 2025
c48467c
Register _clone_dim_order under TOSA support check
keyprocedure Aug 13, 2025
e262a36
Merge branch 'main' into add-dim-order-clone-aot
digantdesai Aug 14, 2025
5546360
Add clone_dim_order_support to TOSA operator support list
keyprocedure Aug 16, 2025
5c5e65a
Register node visitor for _clone_dim_order
keyprocedure Aug 16, 2025
fe7dd11
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 19, 2025
7a0bc6a
Remove visitor node registration for _clone_dim_order
keyprocedure Aug 25, 2025
74e2cce
Remove aten.clone check from RemoveClonePass
keyprocedure Aug 25, 2025
f9f9515
Remove input dtype gating and add memory_format check
keyprocedure Aug 25, 2025
8d0cb06
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 25, 2025
6839212
Add Core ML test for _clone_dim_order
keyprocedure Aug 25, 2025
23 changes: 23 additions & 0 deletions backends/apple/coreml/compiler/torch_ops.py
@@ -15,6 +15,7 @@
from coremltools.converters.mil.frontend.torch.ops import (
_get_inputs,
_get_kwinputs,
noop,
NUM_TO_NUMPY_DTYPE,
NUM_TO_TORCH_DTYPE,
split,
@@ -67,6 +68,28 @@ def _to_dim_order_copy(context, node):
to(context, node)


@register_torch_op(

Contributor: Did we remove support for partitioning the aten.clone op?

Contributor Author: I didn't explicitly remove support for aten.clone. From my understanding, this change only adds support for _clone_dim_order; aten.clone is still supported and handled by the noop handler in coremltools.

Contributor: This change looks OK to me, but do we have a test that covers it?

Contributor Author: I added a test for _clone_dim_order here. Is this what you had in mind?

Contributor: Looks good.

torch_alias=[
"dim_order_ops::_clone_dim_order",
"dim_order_ops._clone_dim_order",
],
override=False,
)
def _clone_dim_order(context, node):
dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
node.kwinputs.pop("dim_order")

# In CoreML, dim_order.val will be an ndarray, so we convert it to a list to check the memory format.
dim_order = [int(d) for d in dim_order.val]
memory_format = get_memory_format(dim_order)
assert (
memory_format == _torch.contiguous_format
), "Only contiguous memory format is supported in CoreML"

# Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
noop(context, node)
Contributor (@metascroy, Aug 25, 2025): @keyprocedure Why not just call clone(context, node) here? It would be equivalent to noop now, but let's use clone if that's what the op is, just in case coremltools updates clone to be something other than a noop in the future.

Contributor Author: That makes a lot of sense! But I don't see clone exposed in coremltools for import; it looks like it's only handled as an alias of noop.

Contributor: I guess this is OK then.



# https://github.com/apple/coremltools/pull/2558
@register_torch_op(
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
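For readers unfamiliar with the dim_order convention used by the handler above: get_memory_format maps a permutation of dimension indices back to a torch memory format. The helper below is an illustrative re-implementation covering only 4D tensors, not the actual executorch utility, so treat its name and coverage as assumptions.

```python
import torch


def get_memory_format_sketch(dim_order: list[int]) -> torch.memory_format:
    """Illustrative mapping from a dim_order permutation to a torch memory format.

    Only the two layouts relevant to this PR are handled: the identity
    permutation (contiguous) and the 4D channels-last permutation.
    """
    if dim_order == list(range(len(dim_order))):
        return torch.contiguous_format
    if dim_order == [0, 2, 3, 1]:
        return torch.channels_last
    raise ValueError(f"Unrecognized dim_order: {dim_order}")


# The CoreML handler above asserts that only the contiguous case is accepted.
assert get_memory_format_sketch([0, 1, 2, 3]) == torch.contiguous_format
assert get_memory_format_sketch([0, 2, 3, 1]) == torch.channels_last
```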
23 changes: 23 additions & 0 deletions backends/apple/coreml/test/test_torch_ops.py
@@ -221,6 +221,28 @@ def test_dequantize_codebook_embedding(self):
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test__clone_dim_order_contiguous(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.ops.dim_order_ops._clone_dim_order(
x, dim_order=[0, 1, 2, 3]
)

model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)


if __name__ == "__main__":
test_runner = TestTorchOps()
@@ -231,3 +253,4 @@ def test_dequantize_codebook_embedding(self):
test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
test_runner.test_dequantize_codebook_linear()
test_runner.test_dequantize_codebook_embedding()
test_runner.test__clone_dim_order_contiguous()
2 changes: 1 addition & 1 deletion backends/arm/_passes/remove_clone_pass.py
@@ -14,7 +14,7 @@ class RemoveClonePass(ExportPass):
"""Remove all clones from graph_module"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.clone.default:
if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
return super().call_operator(op, args, kwargs, meta)

if len(args) != 1:
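The diff above is truncated after the arity check. For context, here is a minimal sketch of the full pass as it reads with this change applied; the error message and the final return are assumptions about how a clone-removal pass typically completes, not lines copied from the repository.

```python
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class RemoveClonePass(ExportPass):
    """Remove all clones from graph_module."""

    def call_operator(self, op, args, kwargs, meta):
        # Only intercept the dim-order-aware clone; leave every other op untouched.
        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
            return super().call_operator(op, args, kwargs, meta)

        if len(args) != 1:
            raise ValueError(f"clone expects a single input, got {len(args)}")

        # The clone is layout-preserving at this point, so forward its input directly.
        return args[0]
```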
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -6,6 +6,7 @@
# pyre-unsafe

from . import ( # noqa
clone_dim_order_support,
convolution_support,
embedding_support,
ethos_u55_support,
76 changes: 76 additions & 0 deletions backends/arm/operator_support/clone_dim_order_support.py
@@ -0,0 +1,76 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging

import torch
import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)


@register_tosa_support_check
class CloneDimOrderSupport(SupportedTOSAOperatorCheck):

Contributor: Did we remove support for partitioning the aten.clone op?

Contributor Author: We didn't; aten.clone is still in the BaseTOSASupportList and supported. CloneDimOrderSupport only applies when cloning happens with _skip_dim_order=False; otherwise aten.clone still exists in the graph.

targets = [
exir_ops.edge.dim_order_ops._clone_dim_order.default,
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
assert node.target in self.targets

# Check input type
assert len(node.all_input_nodes) == 1
input_val = node.all_input_nodes[0].meta["val"]
assert isinstance(input_val, torch._subclasses.FakeTensor)
input_dtype = input_val.dtype

# Check output type
output_val = node.meta["val"]
assert isinstance(output_val, torch._subclasses.FakeTensor)
if output_val.dtype != input_dtype:
self.reporter.report_reject(
node,
f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
)
return False

# Check memory format
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
self.reporter.report_reject(
node,
f"Argument 'memory_format' is not supported for "
f"{node.target} right now.",
)
return False

# Check dim_order
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
self.reporter.report_reject(
node,
f"Argument {dim_order=} is not supported for "
f"{node.target} right now.",
)
return False

return True
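The dim_order branch above boils down to a simple rule: only an identity permutation (i.e. a contiguous layout) is accepted for TOSA delegation. A self-contained illustration of that rule, using a hypothetical helper rather than real fx nodes:

```python
def clone_dim_order_is_supported(dim_order: list[int]) -> bool:
    # Mirrors the check in is_node_tosa_supported: accept only the identity permutation.
    return dim_order == list(range(len(dim_order)))


assert clone_dim_order_is_supported([0, 1, 2, 3])      # contiguous clone -> delegated
assert not clone_dim_order_is_supported([0, 2, 3, 1])  # channels_last clone -> rejected
```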
@@ -38,7 +38,7 @@
]
linear_residual_exir_op: list[str] = [
"executorch_exir_dialects_edge__ops_aten_gelu_default",
"executorch_exir_dialects_edge__ops_aten_clone_default",
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",

Contributor: Are we removing aten.clone here because the corresponding test is dim-order only? @digantdesai, will Arm use dim order mandatorily? If we want to keep supporting both aten.clone and the dim_order clone, let's keep the original test and add a new test for the dim_order case. Same for the other tests.

Contributor Author (@keyprocedure, Aug 19, 2025): Since Arm's pass manager doesn't seem to pass the _skip_dim_order flag, it defaults to False, so my strategy was to replace all instances of aten.clone with _clone_dim_order in the tests. It might make more sense to explicitly set _skip_dim_order=True in the EdgeCompileConfig calls in arm_tester and undo my changes to the Arm test files. Then I can either continue adding _clone_dim_order support and tests in TOSA for future use (following Gasoonjia's suggestion), or leave it out for now.

Contributor: Yeah, my point is that we should make our update consistent with Arm's target. That is, if we expect dim-order-only in the Arm backend, there should be no aten.clone in this PR and we should remove aten.clone support. However, if we need to support dim_order both on and off, we should test both cases for the clone operator, though I'm OK with splitting that into several PRs for better structure.

Contributor Author: That makes sense, thanks for breaking down the options.

Contributor (@digantdesai, Aug 20, 2025): FYI, see 135e875. Also cc @oscarandersson8218 regarding what happens if there is no aten.clone support.

Contributor Author: @digantdesai thanks for looking into this and sharing the enable-dim_order PR. It looks like the best strategy is to replace support for aten.clone with _clone_dim_order in Arm, since the _skip_dim_order flag will always be False, following @Gasoonjia's suggestion.

(A sketch of the _skip_dim_order behavior discussed above follows this file's diff.)

"executorch_exir_dialects_edge__ops_aten_linear_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
]
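As referenced in the thread above, here is a minimal sketch of how the _skip_dim_order flag determines which clone op reaches the Arm backend. The example module and the default=False behavior are assumptions drawn from this PR's discussion and its DimOrderOpsMap change, not a verified test from the repository.

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class CloneModule(torch.nn.Module):
    def forward(self, x):
        return x.clone()


ep = torch.export.export(CloneModule(), (torch.randn(1, 3, 8, 8),))

# With dim order enabled (the assumed default, _skip_dim_order=False),
# aten.clone is rewritten to dim_order_ops._clone_dim_order at the edge level.
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
targets = [
    str(node.target)
    for node in edge.exported_program().graph.nodes
    if node.op == "call_function"
]
assert any("_clone_dim_order" in target for target in targets)

# With _skip_dim_order=True, the aten op is kept, so RemoveClonePass and
# CloneDimOrderSupport would never see _clone_dim_order.
edge_skipped = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
```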
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_clone.py
@@ -23,7 +23,7 @@
)

aten_op = "torch.ops.aten.clone.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default"
exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"

input_t = Tuple[torch.Tensor]

6 changes: 4 additions & 2 deletions backends/arm/test/passes/test_remove_clone_pass.py
@@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT():
module.get_inputs(),
quantize=True,
ops_before_pass={
"executorch_exir_dialects_edge__ops_aten_clone_default": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
},
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"],
ops_not_after_pass=[
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
],
pass_list=[RemoveClonePass],
)
pipeline.run()
19 changes: 19 additions & 0 deletions exir/passes/dim_order_ops_registry.py
@@ -28,6 +28,14 @@
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

lib.define(
"_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs):
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
def _clone_dim_order_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)


@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
def _clone_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)


"""
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
"""
DimOrderOpsMap = {
exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

"""
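Since the registrations above give _clone_dim_order a CompositeImplicitAutograd implementation that dispatches to aten.clone, the op can also be exercised eagerly. A small sketch; the import path is inferred from this diff's file location and is an assumption.

```python
import torch

# Importing the registry module registers the dim_order_ops library defined above
# (path assumed from exir/passes/dim_order_ops_registry.py in this diff).
import executorch.exir.passes.dim_order_ops_registry  # noqa: F401

x = torch.randn(2, 3, 4, 5)

# dim_order [0, 2, 3, 1] corresponds to channels_last for a 4D tensor, so the
# impl above forwards to aten.clone with memory_format=torch.channels_last.
y = torch.ops.dim_order_ops._clone_dim_order(x, dim_order=[0, 2, 3, 1])

assert torch.equal(x, y)                                   # same values
assert y.is_contiguous(memory_format=torch.channels_last)  # new layout
```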
52 changes: 52 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -27,7 +27,10 @@
AmbiguousDimOrderError,
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCloneChannelsLastModule,
PropagateToCopyChannalsLastModule,
SimpleCloneChannelsLastModule,
SimpleCloneContiguousModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguoustModule,
SimpleToCopyChannelsLastModule,
@@ -91,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
),
)

def test_op_clone_replacement_contiguous(self) -> None:
model = SimpleCloneContiguousModule()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_clone_replacement_channels_last(self) -> None:
model = SimpleCloneChannelsLastModule()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
@@ -128,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
check_unambiguous_dim_order=True,
)

def test_op_clone_dim_order_propagation(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=PropagateToCloneChannelsLastModule().eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
dtype=torch.float32,
memory_format=torch.contiguous_format,
),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
check_unambiguous_dim_order=True,
)

def test_op_dim_order_propagation_ambiguous(self) -> None:
try:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
30 changes: 30 additions & 0 deletions exir/tests/test_memory_format_ops_pass_utils.py
@@ -38,6 +38,10 @@
"torch.ops.aten.empty.memory_format",
"executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
),
torch.ops.aten.clone.default: (
"torch.ops.aten.clone.default",
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
}


@@ -70,6 +74,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=torch.double, memory_format=torch.channels_last)


class SimpleCloneContiguousModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.clone(memory_format=torch.contiguous_format)


class SimpleCloneChannelsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.clone(memory_format=torch.channels_last)


class SimpleEmptyContiguoustModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -102,6 +122,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return t1 * t2


class PropagateToCloneChannelsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
t1 = x.clone(memory_format=torch.channels_last)
t2 = t1 + t1
return t1 * t2


class AmbiguousDimOrderError(RuntimeError):
pass
