Skip to content

Conversation

M-Quadra
Copy link
Contributor

This PR is compatible with executorch torch 2.7.

Unit test

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestIndexPut::test_index_put_updates_bool

In torch 2.8.0, the error message is:

ValueError: In op, of type scatter_nd, named index_put, the named input `updates` must have the same data type as the named input `data`. However, updates has dtype fp32 whereas data has dtype int32.

Detail

import torch
import numpy as np
import coremltools as ct


class Model(torch.nn.Module):
    def forward(self, x):
        x = torch.ones(x.shape, dtype=torch.bool)
        y = torch.ones_like(x).bool()
        mask = torch.tensor([True, False, False, False, True, True]).view(3, 2)
        x[mask] = y[mask]
        return x


x = torch.randn(3, 2)
model = Model().eval()

exported_model = torch.export.export(model, (x,)).run_decompositions({})
mlmodel = ct.convert(
    exported_model,
    minimum_deployment_target=ct.target.iOS16,
)

y0 = model(x).numpy()
y1 = mlmodel.predict({"x": x.numpy()})["index_put"]
assert np.equal(y0, y1).all()
  • exported_model.graph in torch 2.7.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%ones_like,), kwargs = {dtype: torch.bool})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)
  • exported_model.graph in torch 2.8.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=2] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%ones_like, None, None, torch.bool), kwargs = {device: cpu, layout: torch.strided})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%ones_like, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)

In torch 2.8.0, the dtype of ones_like may be moved to _assert_tensor_metadata.

@TobyRoseman
Copy link
Collaborator

The max version of PyTorch that we currently support is 2.7.0 and that's what's being used in our CI system.

So there is no way to test this change other than making sure it doesn't break 2.7.0. As result, I'm reluctant to merge this change.

Would it be possible for you to look into what other changes are necessary for coremltools to support PyTorch 2.8.0?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants