
torch.fill_ cannot be applied after an add operation #1920

@fukatani

Description


🐞 Describing the bug

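  • torch.fill_ fails during conversion when it is applied to the result of an add operation; the converter raises ValueError: No matching select or slice.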

This may be related to #1914; a more general solution is probably needed.

Stack Trace

Model is not in eval mode. Consider calling '.eval()' on your model prior to conversion
Traceback (most recent call last):
  File "/Users/ryosukefukatani/work/coremltools/onth9.py", line 26, in <module>
    convert_to="neuralnetwork",
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/_converters_entry.py", line 542, in convert
    main_pipeline=pass_pipeline,
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 217, in _mil_convert
    **kwargs
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 61, in load
    specification_version,
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 335, in __init__
    p(self.graph)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/torchir_passes.py", line 151, in generate_tensor_assignment_ops
    raise ValueError("No matching select or slice.")
ValueError: No matching select or slice.
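The error comes from the generate_tensor_assignment_ops pass in torchir_passes.py. Judging only from the function name and the error message, the pass appears to expect every in-place op to write into a tensor produced by a chain of select/slice ops, which is not the case when fill_ targets the output of an add. The sketch below is a hypothetical, simplified model of that check plus one possible generalization for fill_ on a non-view tensor. The Node class, view_chain, lower_inplace, the general flag and the returned strings are all illustrative assumptions, not coremltools code.

```python
from dataclasses import dataclass, field
from typing import Dict, List


# Hypothetical, simplified stand-in for a TorchIR node (not coremltools' real class).
@dataclass
class Node:
    kind: str                                        # op name: "add", "select", "slice", "fill_", ...
    inputs: List[str] = field(default_factory=list)  # names of input tensors
    output: str = ""                                 # name of the produced tensor


def view_chain(name: str, producers: Dict[str, Node]) -> List[Node]:
    """Walk back from tensor `name`, collecting the select/slice ops that produced it."""
    chain = []
    node = producers.get(name)
    while node is not None and node.kind in ("select", "slice"):
        chain.append(node)
        node = producers.get(node.inputs[0])
    return list(reversed(chain))  # ops in execution order (base tensor first)


def lower_inplace(op: Node, producers: Dict[str, Node], general: bool = False) -> str:
    """Describe how an in-place op would be lowered (illustrative only)."""
    chain = view_chain(op.inputs[0], producers)
    if chain:
        # Target is a view of another tensor: write back through the views.
        return "slice_update via " + " -> ".join(n.kind for n in chain)
    if general and op.kind == "fill_":
        # Hypothetical fallback: filling a whole, non-view tensor is the same
        # as creating a new tensor with full_like(target, value).
        return "full_like"
    # Strict behaviour: mirrors the error in the stack trace.
    raise ValueError("No matching select or slice.")


if __name__ == "__main__":
    # Graph of the repro: y = empty(...) + 1, then y.fill_(0.0).
    producers = {"y": Node("add", ["empty_i32", "one"], "y")}
    fill = Node("fill_", ["y", "zero"], "y_filled")
    try:
        lower_inplace(fill, producers)
    except ValueError as err:
        print("strict:", err)
    print("general:", lower_inplace(fill, producers, general=True))
```

In this model the strict path fails for exactly the graph in the repro, while a fill_ whose target comes from a slice still lowers normally; the general branch shows one way the "more general solution" could treat a fill_ that has no select/slice to match.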

To Reproduce

import torch
import coremltools as ct
import numpy as np


class Net(torch.nn.Module):
    def forward(self, x):
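        # y is the output of an add op, not of a select/slice op.
        y = torch.empty(x.shape).to(torch.int32) + 1
        # In-place fill_ on y: the op that triggers the ValueError above.
        y.fill_(0.0)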
        return y


x = torch.rand(2, 3)
traced_fn = torch.jit.trace(Net(), x)
ct_model = ct.convert(
    traced_fn,
    inputs=[
        ct.TensorType(
            shape=(
                ct.RangeDim(),
                ct.RangeDim(),
            )
        ),
    ],
    source="pytorch",
    convert_to="neuralnetwork",
)

out = traced_fn(x)
out_dict = ct_model.predict(
    {
        'x': x.detach().numpy().astype(np.float32),
    }
)
np.testing.assert_allclose(out, list(out_dict.values())[0], rtol=0.001, atol=0.001)
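Until the converter handles this case, one possible workaround, assuming the model can be edited, is to replace the in-place fill_ with an out-of-place op such as torch.zeros_like, so the traced graph contains no in-place assignment for generate_tensor_assignment_ops to match. NetWorkaround is a hypothetical name; the PyTorch behaviour below is standard, but whether ct.convert then succeeds has not been confirmed.

```python
import torch


class NetWorkaround(torch.nn.Module):
    """Same output as Net in the repro, but without the in-place fill_."""

    def forward(self, x):
        y = torch.empty(x.shape).to(torch.int32) + 1
        # Out-of-place equivalent of y.fill_(0.0): a new tensor with y's
        # shape and dtype (int32), filled with zeros.
        y = torch.zeros_like(y)
        return y


x = torch.rand(2, 3)
traced_fn = torch.jit.trace(NetWorkaround(), x)
out = traced_fn(x)
print(out.dtype, tuple(out.shape))
```

The traced module can be passed to ct.convert with the same arguments as in the repro; since it no longer contains fill_, it is expected, but not verified, to avoid the "No matching select or slice." error.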

System environment:

  • coremltools version: latest master

Metadata

Labels: PyTorch (traced), bug, triaged