
Commit 22e7dbd

Fix triu/tril CoreML lowering error in to_edge_transform_and_lower (#11107)
This PR fixes a CoreML lowering issue with triu/tril when using to_edge_transform_and_lower. It also changes the warning logging in _remove_invalid_ops_for_not_decompose to log each distinct warning only once per call, rather than once per op.
1 parent ea8b4e1 commit 22e7dbd
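
For context, here is a minimal sketch (not part of this commit) of the path the fix exercises: exporting a model that uses triu/tril, then lowering it through to_edge_transform_and_lower with the CoreML partitioner. The import paths and the MaskModel example are assumptions based on common ExecuTorch usage and may need adjusting for your version.

import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class MaskModel(torch.nn.Module):
    # Hypothetical example model; triu/tril previously failed during CoreML
    # lowering when preserved (not decomposed) in the ExecuTorch op namespace.
    def forward(self, x):
        return torch.triu(x, diagonal=1) + torch.tril(x)


ep = torch.export.export(MaskModel(), (torch.randn(4, 4),), strict=True)
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
et_program = edge.to_executorch()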

3 files changed: +46 -9 lines


backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 20 additions & 4 deletions
@@ -110,17 +110,33 @@ def ops_to_not_decompose(
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
         op_support = OperatorsSupportedForCoreMLBackend()
+        _logged_warnings = set()
+
+        # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
+        # TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
+        # desired changes, we need to manually block them here
+        do_not_decompose_blocklist = [
+            # https://github.com/apple/coremltools/blob/release/8.3/coremltools/converters/mil/frontend/torch/ops.py#L6965-L6966
+            torch.ops.aten.triu.default,
+            # https://github.com/apple/coremltools/blob/release/8.3/coremltools/converters/mil/frontend/torch/ops.py#L6997-L6998
+            torch.ops.aten.tril.default,
+        ]
         for node in ep.graph.nodes:
             if node.op == "call_function" and isinstance(
                 node.target, torch._ops.OpOverload
             ):
                 try:
-                    if op_support.is_node_supported(None, node):
+                    if (
+                        op_support.is_node_supported(None, node)
+                        and node.target not in do_not_decompose_blocklist
+                    ):
                         do_not_decompose.append(node.target)
                 except Exception as e:
                     # CoreML's op_support.is_node_supported will sometimes throw
                     # for unsupported ops, rather than returning False
-                    logger.warning(
-                        f"Encountered exception when checking node support: {e}"
-                    )
+                    warn_str = f"Encountered exception when checking node support: {e}"
+                    if warn_str not in _logged_warnings:
+                        logger.warning(warn_str)
+                        _logged_warnings.add(warn_str)
+
         return do_not_decompose, None
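
As a usage sketch (not taken from the diff), the effect of the new blocklist can be observed directly through the partitioner's ops_to_not_decompose method: for an exported program whose graph contains aten.triu/aten.tril, those overloads should no longer appear in the returned preserve list. The variable ep below is assumed to be an exported program like the one built in the test that follows.

import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# `ep` is assumed to be a torch.export.export(...) result whose graph
# contains aten.triu.default / aten.tril.default (see the test below).
do_not_decompose, _ = CoreMLPartitioner().ops_to_not_decompose(ep)
assert torch.ops.aten.triu.default not in do_not_decompose
assert torch.ops.aten.tril.default not in do_not_decompose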

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 15 additions & 1 deletion
@@ -90,6 +90,13 @@ def forward(self, q, k, v, mask):
                     q, k, v, attn_mask=mask
                 )
 
+                # triu/tril should be ignored by do_not_decompose
+                # because otherwise they fail during CoreML lowering
+                offset1 = torch.triu(mask, diagonal=1)
+                offset2 = torch.tril(mask)
+                offset = offset1 + offset2
+                offset = torch.sum(offset)
+
                 # Add non-functional and alias ops
                 # These will be removed by ExecuTorch in non-decomposition
                 # table because they cannot be functionalized
@@ -102,7 +109,7 @@ def forward(self, q, k, v, mask):
                 out = out.sub_(4.0)
                 out = torch.ops.aten.view_copy.default(out, (-1,))
                 out = out.select(0, 0)
-                return out
+                return out + offset
 
         model = Model()
         model.eval()
@@ -118,6 +125,13 @@ def forward(self, q, k, v, mask):
         mask = torch.randn(seq_len, max_seq_length)
         example_inputs = (q, k, v, mask)
         ep = torch.export.export(model, example_inputs, strict=True)
+        self.assertTrue(
+            "torch.ops.aten.triu.default" in ep.graph_module.code,
+        )
+        self.assertTrue(
+            "torch.ops.aten.tril.default" in ep.graph_module.code,
+        )
+
         coreml_partitioner = CoreMLPartitioner()
 
         # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph

exir/program/_program.py

Lines changed: 11 additions & 4 deletions
@@ -1017,6 +1017,13 @@ def _sanity_check_graph_for_non_decomp_ops(
 def _remove_invalid_ops_for_not_decompose(
     ops_to_not_decompose: List[torch._ops.OpOverload],
 ) -> List[torch._ops.OpOverload]:
+    _logged_warnings = set()
+
+    def log_warning(warn_str):
+        if warn_str not in _logged_warnings:
+            logging.warn(warn_str)
+            _logged_warnings.add(warn_str)
+
     # To address https://github.com/pytorch/executorch/issues/8781
     def keep(op):
         # Explicit allow list
@@ -1034,18 +1041,18 @@ def keep(op):
         schema = op._schema
         native_schema = _pybind_schema_to_native_schema(schema)
         if native_schema is None:
-            logging.warn(
+            log_warning(
                 f"Torchgen is not able to parse the schema of {op._schema}. This is not fatal."
             )
         else:
             if native_schema.is_mutable:
-                logging.warn(
+                log_warning(
                     f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable."
                 )
                 return False
 
             if native_schema.aliased_return_names() != [None]:
-                logging.warn(
+                log_warning(
                     f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
                 )
                 return False
@@ -1067,7 +1074,7 @@ def keep(op):
             torch.ops.aten.unbind.int,
             torch.ops.aten.split_with_sizes.default,
         ]:
-            logging.warn(
+            log_warning(
                 f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist."
            )
             return False
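
To make the once-per-call behavior concrete, below is a small standalone illustration (not from the diff) of the same pattern: a fresh set per invocation plus a nested helper, so a message repeated across many ops is logged once per call rather than once per op.

import logging


def check_ops(ops):
    # Hypothetical example; mirrors the dedup pattern used in
    # _remove_invalid_ops_for_not_decompose above.
    _logged_warnings = set()

    def log_warning(warn_str):
        # Log each distinct message at most once for this call.
        if warn_str not in _logged_warnings:
            logging.warning(warn_str)
            _logged_warnings.add(warn_str)

    kept = []
    for op in ops:
        if not hasattr(op, "_schema"):
            log_warning("Skipping an op without a parsable schema. This is not fatal.")
            continue
        kept.append(op)
    return kept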
