25 changes: 22 additions & 3 deletions coremltools/converters/mil/frontend/torch/internal_graph.py
@@ -3,7 +3,7 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

-from collections import OrderedDict
+from collections import OrderedDict, deque
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -584,9 +584,28 @@ def from_exir(cls, exir, cut_at_symbols=None):
else:
outputs = user_outputs

+preserved_node_targets = set([
+    # has no users, but encodes dtype information for some ops, so keep it
+    torch.ops.aten._assert_tensor_metadata.default,
+])
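+# Dead-code elimination: seed a queue with every call_function node that
+# has no users, then walk each pruned node's inputs, decrementing each
+# producer's user count and enqueueing producers that become dead in turn
+# (a reverse BFS over the FX graph). Preserved targets are never pruned.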
+graph_nodes = exported_program.graph_module.graph.nodes
+user_count = {n: len(n.users) for n in graph_nodes}
+def _is_skip(node: torch.fx.Node):
+    return user_count[node] <= 0 and node.op == "call_function" and node.target not in preserved_node_targets
+
+skip_nodes = set()
+node_que = deque([n for n in graph_nodes if _is_skip(n)])
+while node_que:
+    n = node_que.popleft()
+    skip_nodes.add(n)
+    for in_n in n.all_input_nodes:
+        user_count[in_n] -= 1
+        if _is_skip(in_n):
+            node_que.append(in_n)

nodes = []
-for node in exported_program.graph_module.graph.nodes:
-    if node.op == "call_function":
+for node in graph_nodes:
+    if node in skip_nodes:
+        continue
+    elif node.op == "call_function":
nodes.append(InternalTorchIRNode.from_exir_node(node=node))
elif node.op == "get_attr":
name = node.target
38 changes: 23 additions & 15 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -1704,12 +1704,14 @@ def _parse_positional_args(context, node) -> Tuple[Var]:

x = inputs[0]
alpha = 1 if nargs < 2 else inputs[1]
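+# aten.elu's schema also carries an optional output `scale` argument; it
+# is only expected in ExecuTorch graphs here, so parse it there and leave
+# it as None for the other frontends.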
+scale = None if nargs < 3 or context.frontend != TorchFrontend.EXECUTORCH else inputs[2]
+return x, alpha, scale
-return x, alpha

-x, alpha = _parse_positional_args(context, node)
-res = mb.elu(x=x, alpha=alpha, name=node.name)
+x, alpha, scale = _parse_positional_args(context, node)
+res = mb.elu(x=x, alpha=alpha)
+if scale is not None:
+    res = mb.mul(x=res, y=scale)
+res.name = node.name
context.add(res)


@@ -5505,7 +5507,7 @@ def _parse_keyword_args(context, node, dtype) -> Var:

@register_torch_op
def ones_like(context, node):
-def _parse_positional_args(context, node) -> Tuple[Var, Optional[Var]]:
+def _parse_positional_args(context, node) -> Tuple[Var, Optional[str]]:
inputs = _get_inputs(
context,
node,
@@ -5516,22 +5518,28 @@ def _parse_positional_args(context, node) -> Tuple[Var, Optional[Var]]:
dtype = None
if len(inputs) > 1 and inputs[1] is not None:
dtype = inputs[1]
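+# Resolve dtype with decreasing priority: positional arg, "dtype" kwarg,
+# the dtype recorded in the FX node's tensor_meta, then the input's own
+# dtype; normalize Var/int forms to a MIL dtype string along the way.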
+if dtype is None:
+    dtype = _get_kwinputs(context, node, "dtype", default=[dtype])[0]
+if dtype is None and node.meta is not None:
+    dtype = TORCH_DTYPE_TO_NUM[node.meta["tensor_meta"].dtype]
+if isinstance(dtype, Var):
+    dtype = dtype.val.item()
+if isinstance(dtype, int):
+    dtype = NUM_TO_DTYPE_STRING[dtype]
+if dtype is None:
+    dtype = types.builtin_to_string(x.dtype)
return x, dtype

-def _parse_keyword_args(context, node, dtype) -> Var:
-    dtype = _get_kwinputs(context, node, "dtype", default=[dtype])[0]
-    return dtype

x, dtype = _parse_positional_args(context, node)
-dtype = _parse_keyword_args(context, node, dtype)

if is_current_opset_version_compatible_with(target.iOS16):
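+# Use a fill value whose numpy dtype matches the requested dtype,
+# presumably so fill_like produces the requested output dtype up front
+# and the final cast below is a no-op in the common cases.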
-res = mb.fill_like(ref_tensor=x, value=1.0)
+v = {
+    "fp16": np.float16(1.0),
+    "fp32": np.float32(1.0),
+    "int32": np.int32(1),
+    "bool": np.bool_(True),
+}.get(dtype, 1.0)
+res = mb.fill_like(ref_tensor=x, value=v)
else:
res = mb.fill(shape=mb.shape(x=x), value=1.0)
-# By default use input x's dtype.
-dtype_str = NUM_TO_DTYPE_STRING[dtype.val] if dtype is not None else types.builtin_to_string(x.dtype)
-res = _cast_to(res, dtype_str, node.name)
+res = _cast_to(res, dtype, node.name)
context.add(res, node.name)


15 changes: 7 additions & 8 deletions coremltools/converters/mil/frontend/torch/quantization_ops.py
@@ -691,20 +691,19 @@ def _weight_int4pack_mm(context, node):
.contiguous()
)
with _torch.no_grad():
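+# Newer torchao (0.12 in reqs) drops quantize_affine's zero_point_domain
+# argument and expects integer-domain zero points; a float scale dtype
+# signals float-domain zero points here, converted as
+# zp_int = round(mid_point - zp_float / scale).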
-zero_point_domain = (
-    torchao_quant.ZeroPointDomain.INT
-    if _np.issubdtype(zero_points.dtype, _np.integer)
-    else torchao_quant.ZeroPointDomain.FLOAT
-)
+scale = _torch.from_numpy(scales)
+zero_point = _torch.from_numpy(zero_points)
+if scale.dtype == _torch.float32:
+    mid_point = (quant_max + quant_min + 1) / 2
+    zero_point = _torch.round(mid_point - zero_point / scale).to(_torch.int32)
y_quantized = torchao_quant.quantize_affine(
y_dequantized,
(1, group_size),
-_torch.from_numpy(scales),
-_torch.from_numpy(zero_points),
+scale,
+zero_point,
_torch.int32,
quant_min=quant_min,
quant_max=quant_max,
-zero_point_domain=zero_point_domain,
)
y_quantized = y_quantized.numpy().astype(_np.uint8)
if len(y_quantized.shape) != 2:
@@ -382,7 +382,7 @@ def forward(self, input, other):
mul = mil_program.functions["main"].find_ops(op_type="mul")[0]

stack_trace = mul.scopes[ScopeSource.EXIR_STACK_TRACE][0]
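+# The EXIR stack-trace string format changed in newer torch, so the index
+# of the source statement within the trace moved (here and in the similar
+# asserts below).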
-assert stack_trace.split("\n")[-2].strip() == "return input * other"
+assert stack_trace.split("\n")[-1].strip() == "return input * other"

if frontend == TorchFrontend.EXECUTORCH:
debug_handle = mul.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0]
@@ -464,7 +464,7 @@ def forward(self, arg):
linear = mil_program.functions["main"].find_ops(op_type="linear")[0]

stack_trace = linear.scopes[ScopeSource.EXIR_STACK_TRACE][0]
-assert stack_trace.split("\n")[-2].strip() == "return self.linear(arg)"
+assert stack_trace.split("\n")[1].strip() == "return self.linear(arg)"

if frontend == TorchFrontend.EXECUTORCH:
debug_handle = linear.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0]
@@ -555,7 +555,7 @@ def forward(self, x, y):
"z = z + z",
]
for i, stack_trace in enumerate(stack_traces):
-assert stack_trace.split("\n")[-2].strip() == source_codes[i]
+assert stack_trace.strip().split("\n")[-1].strip() == source_codes[i]

if frontend == TorchFrontend.EXECUTORCH:
debug_handles = [add.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] for add in adds]
@@ -674,7 +674,7 @@ def forward(self, a, x, b):
for op_type in ("matmul", "add"):
stack_trace = stack_traces[op_type]
source_code = source_codes[op_type]
-assert stack_trace.split("\n")[-2].strip() == source_code
+assert stack_trace.strip().split("\n")[-1].strip() == source_code

if frontend == TorchFrontend.EXECUTORCH:
debug_handle = {
@@ -767,7 +767,7 @@ def forward(self, x):
softmax = mil_program.functions["main"].find_ops(op_type="softmax")[0]

stack_trace = softmax.scopes[ScopeSource.EXIR_STACK_TRACE][0]
-assert stack_trace.split("\n")[-2].strip() == "return self.softmax(x)"
+assert stack_trace.split("\n")[1].strip() == "return self.softmax(x)"

if frontend == TorchFrontend.EXECUTORCH:
debug_handle = softmax.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0]
15 changes: 8 additions & 7 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -1204,15 +1204,14 @@ def test_batchnorm_dynamic(
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, has_weight, has_bias, has_running_mean, has_running_var",
"compute_unit, backend, frontend, has_weight, has_bias, has_running_mean_and_var",
itertools.product(
compute_units,
backends,
frontends,
[True, False],
[True, False],
[True, False],
-[True, False],
),
)
def test_batchnorm_dynamic_stress(
@@ -1222,8 +1221,7 @@ def test_batchnorm_dynamic_stress(
frontend,
has_weight,
has_bias,
-has_running_mean,
-has_running_var,
+has_running_mean_and_var,
):
if frontend in TORCH_EXPORT_BASED_FRONTENDS:
pytest.skip("torch.export converter does not handle input mutation")
@@ -1233,8 +1231,8 @@

weight = torch.randn(num_features) if has_weight else None
bias = torch.randn(num_features) if has_bias else None
-running_mean = torch.randn(num_features) if has_running_mean else None
-running_var = torch.randn(num_features) if has_running_var else None
+running_mean = torch.randn(num_features) if has_running_mean_and_var else None
+running_var = torch.randn(num_features) if has_running_mean_and_var else None

class Model(torch.nn.Module):
def forward(self, x):
@@ -13216,7 +13214,10 @@ def test_unfold(
input_type, dynamic_shapes = None, None
if is_dynamic_hw:
h_coreml, w_coreml = RangeDim(min_h, 128), RangeDim(min_w, 128)
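+# torch.export.Dim.AUTO asks export to infer the dynamic range itself;
+# the explicit min/max Dim constraints are kept only on x86_64 hosts.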
-h_torch, w_torch = torch.export.Dim("h", min=min_h, max=128), torch.export.Dim("w", min=min_w, max=128)
+if platform.machine() == "x86_64":
+    h_torch, w_torch = torch.export.Dim("h", min=min_h, max=128), torch.export.Dim("w", min=min_w, max=128)
+else:
+    h_torch, w_torch = torch.export.Dim.AUTO, torch.export.Dim.AUTO
input_type = [ct.TensorType(name="x", shape=ct.Shape([input_shape[0], input_shape[1], h_coreml, w_coreml]))]
dynamic_shapes = {"args": ((input_shape[0], input_shape[1], h_torch, w_torch),)}

@@ -603,28 +603,30 @@ def test_unpack_int4packed_by_mm_with_eye_matrix(self, use_numpy, inner_k_tiles,
scales = torch.transpose(y_scales_and_zeros[:, :, 0], 0, 1)
zero_points = torch.transpose(y_scales_and_zeros[:, :, 1], 0, 1)
block_size = (1, group_size)
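+# Mirror the converter's zero-point handling: express the float-domain
+# zero points in integer domain (zp_int = round(mid_point - zp_float / scale)),
+# since newer torchao's quantize_affine no longer takes zero_point_domain.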

+quant_min, quant_max = 0, 2**4 - 1
+mid_point = (quant_max + quant_min + 1) / 2
+zero_points = mid_point - zero_points / scales
y_dequant_quantized = torchao_quant.quantize_affine(
y_dequant,
block_size,
scales,
-zero_points,
+torch.round(zero_points).to(torch.int32),
torch.int32,
-quant_min=0,
-quant_max=2**4 - 1,
-zero_point_domain=torchao_quant.ZeroPointDomain.FLOAT,
+quant_min=quant_min,
+quant_max=quant_max,
)
assert torch.equal(y_quantized, y_dequant_quantized)

# The torchao dequantization utils should be able to recover the original y.
y_dequantized_by_torchao = torchao_quant.dequantize_affine(
y_quantized,
-(1, group_size),
+block_size,
scales,
zero_points,
torch.int32,
-quant_min=0,
-quant_max=2**4 - 1,
-zero_point_domain=torchao_quant.ZeroPointDomain.FLOAT,
+quant_min=quant_min,
+quant_max=quant_max,
)
np.testing.assert_allclose(y_dequant.numpy(), y_dequantized_by_torchao.numpy(), rtol=4e-3)

@@ -20,7 +20,7 @@
from coremltools._deps import _HAS_TORCH_EXPORT_API
if _HAS_TORCH_EXPORT_API:
from torch.export import export_for_training
-from torch.ao.quantization.quantize_pt2e import (
+from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
@@ -164,6 +164,10 @@ def quantize_model(
):
quantizer = CoreMLQuantizer(quantization_config)
exported_model = export_for_training(model, (data,)).module()
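+# CoreMLQuantizer's pattern matching appears to rely on "source_fn_stack"
+# node metadata, which export_for_training no longer always populates;
+# stub a placeholder entry so matching still works.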
+for node in exported_model.graph.nodes:
+    if "source_fn_stack" not in node.meta:
+        node.meta["source_fn_stack"] = [("dummy", nn.Module)]

if is_qat:
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
else:
4 changes: 2 additions & 2 deletions reqs/pytorch.pip
@@ -5,12 +5,12 @@ torchaudio==2.2.0; platform_machine != "arm64"
torchvision==0.17.0; platform_machine != "arm64"

# Torch dependencies for ARM
-torch>=2.2.0,<=2.7.0; platform_machine == "arm64"
+torch~=2.8.0; platform_machine == "arm64"
torchaudio>=2.2.0; platform_machine == "arm64"
torchvision>=0.17.0; platform_machine == "arm64"
torchsr==1.0.4; platform_machine == "arm64"

# TODO (rdar://141476729) support a more recent timm
timm==0.6.13; platform_machine == "arm64"

-torchao==0.10.0; platform_machine == "arm64" and python_version >= '3.10'
+torchao==0.12.0; platform_machine == "arm64" and python_version >= '3.10'
2 changes: 1 addition & 1 deletion reqs/test_executorch.pip
@@ -3,4 +3,4 @@

# Warning: Starting from ExecuTorch 0.6.0, coremltools is added as a dependency
# so we need to re-install built-from-source coremltools after pip install ExecuTorch
-executorch>=0.6.0; platform_machine == "arm64" and python_version >= '3.10' and python_version < '3.13'
+executorch>=0.7.0; platform_machine == "arm64" and python_version >= '3.10' and python_version < '3.13'