Conversation

@M-Quadra
Contributor

Adapt to torch 2.8.0, torchao 0.12.0, and executorch 0.7.0.

Environment

executorch==0.7.0
├── torch [required: >=2.8.0,<2.9.0, installed: 2.8.0]
└── torchao [required: ==0.12.0, installed: 0.12.0]

Features and Bug Fixes

Fix test_batchnorm_dynamic_stress

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestBatchNorm::test_batchnorm_dynamic_stress

For torch.nn.functional.batch_norm:

E       ValueError: running_mean and running_var must either both be None or neither be None
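
A minimal sketch of the new check, independent of the converter (torch 2.8):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)
mean, var = torch.zeros(3), torch.ones(3)

# torch 2.8 validates the running-stats pair up front:
F.batch_norm(x, mean, var, training=True)   # OK: both provided
F.batch_norm(x, None, None, training=True)  # OK: both None
# F.batch_norm(x, mean, None, training=True)  # ValueError in torch 2.8
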
Adapt ones_like dtype for torch 2.8.0

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestIndexPut::test_index_put_updates_bool
ValueError: In op, of type scatter_nd, named index_put, the named input `updates` must have the same data type as the named input `data`. However, updates has dtype fp32 whereas data has dtype int32.

import torch
import numpy as np
import coremltools as ct


class Model(torch.nn.Module):
    def forward(self, x):
        x = torch.ones(x.shape, dtype=torch.bool)
        y = torch.ones_like(x).bool()
        mask = torch.tensor([True, False, False, False, True, True]).view(3, 2)
        x[mask] = y[mask]
        return x


x = torch.randn(3, 2)
model = Model().eval()

exported_model = torch.export.export(model, (x,)).run_decompositions({})
mlmodel = ct.convert(
    exported_model,
    minimum_deployment_target=ct.target.iOS16,
)

y0 = model(x).numpy()
y1 = mlmodel.predict({"x": x.numpy()})["index_put"]
assert np.equal(y0, y1).all()
  • exported_model.graph in torch 2.7.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%ones_like,), kwargs = {dtype: torch.bool})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)
  • exported_model.graph in torch 2.8.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=2] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%ones_like, None, None, torch.bool), kwargs = {device: cpu, layout: torch.strided})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%ones_like, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)

Comparing the two dumps: in torch 2.8.0, ones_like no longer gets a following _to_copy node carrying the dtype; the dtype may instead appear only in an aten._assert_tensor_metadata node, which the converter now has to consult.
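
One way to recover the dtype is to peek at the consuming assert node. A sketch, assuming the converter can walk node.users of the FX node (the helper name is hypothetical; arg positions follow the dump above):

import torch

def dtype_from_assert_metadata(node):
    # In the torch 2.8 dump above, the dtype only appears as the 4th
    # positional arg of aten._assert_tensor_metadata.default,
    # e.g. args = (%ones_like, None, None, torch.bool)
    for user in node.users:
        if user.target is torch.ops.aten._assert_tensor_metadata.default:
            if len(user.args) > 3:
                return user.args[3]
    return None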

Fix test_unfold

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestUnfold::test_unfold
E                   torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (w)! For more information, run with TORCH_LOGS="+dynamic".
E                     - Not all values of w = L['args'][0].size()[3] in the specified range 3 <= w <= 128 satisfy the generated guard max(1, (1 + L['args'][0].size()[3]) // 2) == ((1 + L['args'][0].size()[3]) // 2).

To fix the unit test, replace the following code:

h_torch, w_torch = torch.export.Dim("h", min=min_h, max=128), torch.export.Dim("w", min=min_w, max=128)

with:

h_torch, w_torch = torch.export.Dim.AUTO, torch.export.Dim.AUTO
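
Dim.AUTO lets torch.export infer bounds from the model instead of verifying hand-written ones. A sketch of how the dims are presumably threaded through (kernel and input sizes here are illustrative):

import torch

class Unfold(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.unfold(x, kernel_size=2, stride=2)

x = torch.randn(1, 3, 16, 16)
h_torch, w_torch = torch.export.Dim.AUTO, torch.export.Dim.AUTO
exported = torch.export.export(
    Unfold().eval(),
    (x,),
    dynamic_shapes={"x": {2: h_torch, 3: w_torch}},
)
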
Adapt torch.nn.SELU for executorch 0.7.0

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestActivation::test_selu
>               np.testing.assert_allclose(coreml_result, torch_result, atol=atol, rtol=rtol)
E               AssertionError: 
E               Not equal to tolerance rtol=1e-05, atol=0.0001
E               
E               Mismatched elements: 6 / 7 (85.7%)
E               Max absolute difference among violations: 0.3042059
E               Max relative difference among violations: 0.04825448
E                ACTUAL: array([-1.669116, -1.642616, -1.446812,  0.      ,  2.      ,  4.      ,
E                       6.      ], dtype=float32)
E                DESIRED: array([-1.753741, -1.725899, -1.520167,  0.      ,  2.101402,  4.202804,
E                       6.304206], dtype=float32)

coremltools/converters/mil/frontend/torch/test/testing_utils.py:315: AssertionError

import torch
import executorch.exir


x = torch.tensor([-6.0, -4.0, -2.0, 0.0, 2.0, 4.0, 6.0])
model = torch.nn.SELU(inplace=False).eval()
model_spec = torch.export.export(
    model, (x,)
).run_decompositions({})
model_spec = executorch.exir.to_edge(model_spec).exported_program()
print(model_spec.graph)
  • graph in executorch 0.7.0
%input_1 : [num_users=1] = placeholder[target=input]
%aten_elu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.elu.default](args = (%input_1, 1.6732632423543772, 1.0507009873554805), kwargs = {})
return (aten_elu_default,)

The existing translation entry point in coremltools:

@register_torch_op
def elu(context, node):

In executorch 0.7.0, torch.nn.SELU is decomposed into aten.elu, but the existing elu translation is missing support for the scale argument.
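
The relationship is easy to verify against the numbers in the failure above (2.101402 ≈ 2 × 1.0507):

import torch

x = torch.tensor([-6.0, -4.0, -2.0, 0.0, 2.0, 4.0, 6.0])
alpha, scale = 1.6732632423543772, 1.0507009873554805

# aten.elu carries SELU's scale as its third argument:
#   elu(x, alpha, scale) = scale * x                     for x > 0
#                        = scale * alpha * (exp(x) - 1)  for x <= 0
y = torch.ops.aten.elu(x, alpha, scale)
assert torch.allclose(y, torch.nn.functional.selu(x))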

Fix source code assert

pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py::TestExecuTorchExamples::test_add
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py::TestExecuTorchExamples::test_add_mul
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py::TestExecuTorchExamples::test_linear
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py::TestExecuTorchExamples::test_mul
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py::TestExecuTorchExamples::test_softmax
Adapt torchao 0.12.0

pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py::TestTorchao::test_weight_only_quantization
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py::TestPytorchQuantizedOps::test_unpack_int4packed_by_mm_with_eye_matrix
pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py::TestPytorchQuantizedOps::test_weight_int4pack_mm
pytest -x coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py::test_weight_module_act_fusion
    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
>           return func(*args, **kwargs)
E           TypeError: quantize_affine() got an unexpected keyword argument 'zero_point_domain'

                and (
                    torch.nn.modules.batchnorm.BatchNorm2d
>                   in [val[1] for val in node.meta["source_fn_stack"]]
                    or torch.nn.modules.batchnorm.BatchNorm1d
                    in [val[1] for val in node.meta["source_fn_stack"]]
                )
            ):
E           KeyError: 'source_fn_stack'

In torchao 0.12.0, torchao_quant.quantize_affine no longer accepts the zero_point_domain argument.
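
A sketch of a version-tolerant call site (tensor shapes, dtypes, and the INT domain here are illustrative, not the converter's actual arguments):

import inspect

import torch
from torchao.quantization.quant_primitives import quantize_affine

w = torch.randn(4, 8)
block_size = (1, 8)
scale = torch.full((4, 1), 0.1)
zero_point = torch.zeros(4, 1, dtype=torch.int32)

kwargs = {}
if "zero_point_domain" in inspect.signature(quantize_affine).parameters:
    # only older torchao still accepts this argument
    from torchao.quantization.quant_primitives import ZeroPointDomain
    kwargs["zero_point_domain"] = ZeroPointDomain.INT

q = quantize_affine(w, block_size, scale, zero_point, torch.int8, **kwargs)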

torchao.quantization.pt2e expects node.meta["source_fn_stack"] to be present; the fix adds a dummy placeholder entry when the key is missing.
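
A minimal sketch of such a pre-pass (the empty-list default is an assumption; it makes the BatchNorm membership tests above evaluate to False instead of raising):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    # pt2e iterates node.meta["source_fn_stack"], so an empty stack is a
    # safe placeholder for nodes that never recorded a source function
    node.meta.setdefault("source_fn_stack", [])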

Add dead node elimination

pytest -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestEinsum::test_ternary_einsum
NotImplementedError: Unsupported fx node or__3, kind or_

Some nodes in the exported FX graph are dead (their outputs are never consumed), but graph validation does not skip them, so conversion fails on unsupported dead nodes such as the or_ node above.
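
The generic torch.fx mechanism looks like this (the pass added in this PR may differ in detail; operator.or_ matches the or_ node kind in the failure):

import operator

import torch

def f(x):
    unused = x | x  # traces to an operator.or_ node with no users
    return x + 1

gm = torch.fx.symbolic_trace(f)
assert any(n.target is operator.or_ for n in gm.graph.nodes)
gm.graph.eliminate_dead_code()  # drops pure nodes whose outputs are unused
gm.recompile()
assert not any(n.target is operator.or_ for n in gm.graph.nodes)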

Related Issue

Closes #2610.

@TobyRoseman
Collaborator

Thanks for the Pull Request @M-Quadra! It would be great to support a newer version of PyTorch.

The code changes look good to me.

I've kicked off a CI run:
https://gitlab.com/coremltools1/coremltools/-/pipelines/2222858268
Even if this passes, I'll likely need to do more internal testing before we can merge this change.

@M-Quadra
Contributor Author

PyTorch on the x86_64 platform is older and behaves differently; I'll mark the test with pytest.xfail soon.

@TobyRoseman
Collaborator

That's a whole lot of tests to xfail. I think you should look into why they're failing.

x86_64 machines are running a different version of PyTorch:

torch==2.2.0; platform_machine != "arm64"

That's likely the cause.

@M-Quadra
Contributor Author

Failing test:

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestUnfold::test_unfold
E           AttributeError: 'function' object has no attribute 'AUTO'

torch.export.Dim.AUTO is only available in newer PyTorch releases; torch 2.2.0 does not have it, so the attribute lookup fails.

Solution

Use platform.machine() == "x86_64" to branch on the older PyTorch version, as sketched below.
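
A sketch of the gate (min_h and min_w use the lower bound from the guard failure above; the real values come from the test parameterization):

import platform

import torch

min_h, min_w = 3, 3

if platform.machine() == "x86_64":
    # Intel CI pins torch==2.2.0, which has no torch.export.Dim.AUTO
    h_torch = torch.export.Dim("h", min=min_h, max=128)
    w_torch = torch.export.Dim("w", min=min_w, max=128)
else:
    h_torch, w_torch = torch.export.Dim.AUTO, torch.export.Dim.AUTO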
