Skip to content

Error Saving YOLOv8 Model After CoreML Conversion #4512

Open
@gustavofuhr

Description

@gustavofuhr

🐛 Describe the bug

I'm trying to export the ultralytics YOLOv8 model using the CoreML backend, but I'm getting an error when saving the serialized lowered module.

By the way, I am able to save the model in the portable format.

Here's the code I'm using:

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.exir.backend.backend_api import to_backend

def generate_compile_specs_from_args(fp16=False, compile=False):
    """Build the compile specs handed to the CoreML backend.

    Args:
        fp16: If truthy, request ``ct.precision.FLOAT16``; otherwise
            ``ct.precision.FLOAT32``.
        compile: If truthy, request ``CoreMLBackend.MODEL_TYPE.COMPILED_MODEL``;
            otherwise ``CoreMLBackend.MODEL_TYPE.MODEL``.

    Returns:
        The result of ``CoreMLBackend.generate_compile_specs`` for the
        selected precision, compute unit and model type.
    """
    # The model type follows the ``compile`` flag.
    model_kinds = CoreMLBackend.MODEL_TYPE
    if compile:
        chosen_kind = model_kinds.COMPILED_MODEL
    else:
        chosen_kind = model_kinds.MODEL

    # Precision is either FLOAT16 or FLOAT32.
    if fp16:
        precision = ct.precision.FLOAT16
    else:
        precision = ct.precision.FLOAT32

    # The compute unit is fixed to the "ALL" member; it is not configurable
    # through this function's arguments.
    target_unit = ct.ComputeUnit["ALL"]

    return CoreMLBackend.generate_compile_specs(
        compute_precision=precision,
        compute_unit=target_unit,
        model_type=chosen_kind,
    )

def lower_to_coreml_backend(to_be_lowered_module, fp16, compile):
    """Lower ``to_be_lowered_module`` through ``to_backend`` for CoreML.

    Args:
        to_be_lowered_module: The program object forwarded unchanged to
            ``to_backend``.
        fp16: Forwarded to ``generate_compile_specs_from_args``.
        compile: Forwarded to ``generate_compile_specs_from_args``.

    Returns:
        Whatever ``to_backend`` returns for the CoreML backend.
    """
    # The backend is identified by the class name of ``CoreMLBackend``.
    backend_id = CoreMLBackend.__name__
    compile_specs = generate_compile_specs_from_args(fp16, compile)
    return to_backend(backend_id, to_be_lowered_module, compile_specs)

# Reproduction script: load the model, trace/export it, lower the whole
# exported program to the CoreML backend, and write the serialized result.
# MODEL_NAME and MODEL_SIZE are defined outside this snippet.
model = YOLO(f"{MODEL_NAME}").model #.model is the actual model object

# Everything below, including export and serialization, runs inside
# torch.inference_mode().
with torch.inference_mode():
    # Put the model on CPU, disable gradients on every parameter, and switch
    # to float / eval mode before tracing.
    model.to("cpu")
    for p in model.parameters():
        p.requires_grad = False
    model.float()
    model.eval()

    # NOTE(review): fuse() comes from the model object itself; presumably it
    # merges layers for inference -- confirm against the Ultralytics docs.
    model = model.fuse()
    y = None
    im = torch.zeros(1, 3, *MODEL_SIZE).to("cpu")

    # Two warm-up forward passes; `y` is not read again in this snippet.
    for _ in range(2):
        y = model(im)  # dry runs

    # Trace with a random example input of the same shape as `im`, then
    # export the captured graph.
    example_args = (torch.randn(1, 3, *MODEL_SIZE),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_args)
    aten_dialect = torch.export.export(pre_autograd_aten_dialect, example_args)

    # Export and lower the module to Edge Dialect
    edge_program = to_edge(aten_dialect)
    
    # Lower the module to CoreML backend
    # The entire exported program is handed to the backend (fp32, not compiled).
    to_be_lowered_module = edge_program.exported_program()
    lowered_module = lower_to_coreml_backend(to_be_lowered_module, fp16=False, compile=False)
    
    # Serialize and save it to a file
    # Per the traceback in this report, the exception is raised from the
    # lowered_module.buffer() call below (during SpecPropPass).
    with open("models/yolo_executorch_coreml.pte", "wb") as file:
        file.write(lowered_module.buffer())

It's giving me the following error:

{
	"name": "Exception",
	"message": "An error occurred when running the 'SpecPropPass' pass after the following passes: []",
	"stack": "---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_manager.py:271, in PassManager.__call__(self, module)
    270 try:
--> 271     res = fn(module)
    273     if not isinstance(res, PassResult) and not hasattr(
    274         res, \"graph_module\"
    275     ):

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_base.py:41, in PassBase.__call__(self, graph_module)
     40 self.requires(graph_module)
---> 41 res = self.call(graph_module)
     42 self.ensures(graph_module)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:572, in _ExportPassBase.call(self, graph_module)
    571 with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
--> 572     result = self.call_submodule(graph_module, tuple(inputs))
    574 return result

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:658, in ExportPass.call_submodule(self, graph_module, inputs)
    655 def call_submodule(
    656     self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    657 ) -> PassResult:
--> 658     res = super().call_submodule(graph_module, inputs)
    660     def preserve_original_ph_meta_val(
    661         gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule
    662     ) -> None:

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:535, in _ExportPassBase.call_submodule(self, graph_module, inputs)
    534 with fx_traceback.preserve_node_meta():
--> 535     interpreter.run(*inputs_data)
    537 new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/interpreter.py:146, in Interpreter.run(self, initial_env, enable_io_processing, *args)
    145 try:
--> 146     self.env[node] = self.run_node(node)
    147 except Exception as e:

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:375, in _ExportPassBase.ExportInterpreter.run_node(self, n)
    374 self.callback.node_debug_str = n.format_node()
--> 375 return super().run_node(n)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/interpreter.py:203, in Interpreter.run_node(self, n)
    202 assert isinstance(kwargs, dict)
--> 203 return getattr(self, n.op)(n.target, args, kwargs)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:605, in ExportPass.ExportInterpreter.call_function(self, target, args, kwargs)
    604     value, key = args
--> 605     return self.callback.call_getitem(value, key, meta)
    606 elif isinstance(target, EdgeOpOverload):

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/passes/spec_prop_pass.py:100, in SpecPropPass.call_getitem(self, value, key, meta)
     99 meta[\"spec\"] = value.node.meta[\"spec\"][key]
--> 100 return super().call_getitem(value, key, meta)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:517, in _ExportPassBase.call_getitem(self, value, key, meta)
    514 def call_getitem(
    515     self, value: ProxyValue, key: int, meta: NodeMetadata
    516 ) -> ProxyValue:
--> 517     return self._fx(\"call_function\", operator.getitem, (value, key), {}, meta)

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py:397, in _ExportPassBase._fx(self, kind, target, args, kwargs, meta)
    394 args_data, kwargs_data = pytree.tree_map_only(
    395     ProxyValue, lambda x: x.data, (args, kwargs)
    396 )
--> 397 res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
    398 args_proxy, kwargs_proxy = pytree.tree_map_only(
    399     ProxyValue, lambda x: x.proxy, (args, kwargs)
    400 )

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/interpreter.py:275, in Interpreter.call_function(self, target, args, kwargs)
    274 # Execute the function and return the result
--> 275 return target(*args, **kwargs)

IndexError: tuple index out of range

While executing %getitem_25 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 2), kwargs = {})
Original traceback:
None

The above exception was the direct cause of the following exception:

Exception                                 Traceback (most recent call last)
Cell In[24], line 41
     39 # Serialize and save it to a file
     40 with open(\"models/yolo_executorch_coreml.pte\", \"wb\") as file:
---> 41     file.write(lowered_module.buffer())

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/lowered_backend_module.py:149, in LoweredBackendModule.buffer(self, extract_delegate_segments, segment_alignment, constant_tensor_alignment, delegate_alignment)
    143 \"\"\"
    144 Returns a buffer containing the serialized ExecuTorch binary.
    145 \"\"\"
    146 # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
    147 out = bytes(
    148     _serialize_pte_binary(
--> 149         program=self.program(),
    150         extract_delegate_segments=extract_delegate_segments,
    151         segment_alignment=segment_alignment,
    152         constant_tensor_alignment=constant_tensor_alignment,
    153         delegate_alignment=delegate_alignment,
    154     )
    155 )
    156 return out

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/lowered_backend_module.py:322, in LoweredBackendModule.program(self, emit_stacktrace)
    302 # Double check the ExportedProgram data(especially everything except graph) is good
    303 exported_program = ExportedProgram(
    304     root=lowered_exported_program.graph_module,
    305     graph=lowered_exported_program.graph,
   (...)
    320     verifier=lowered_exported_program.verifier,
    321 )
--> 322 exported_program = _transform(
    323     exported_program, SpecPropPass(), MemoryPlanningPass(\"greedy\")
    324 )
    325 emitted_program = emit_program(
    326     exported_program, emit_stacktrace=emit_stacktrace
    327 ).program
    328 return emitted_program

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py:179, in _transform(self, *passes)
    177 def _transform(self, *passes: PassType) -> \"ExportedProgram\":
    178     pm = PassManager(list(passes))
--> 179     res = pm(self.graph_module)
    180     transformed_gm = res.graph_module if res is not None else self.graph_module
    181     assert transformed_gm is not None

File ~/projects/object_detection_ios_comprehensive/yolov5_yolov8_ultralytics_to_executorch/.executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_manager.py:297, in PassManager.__call__(self, module)
    292         prev_pass_names = [
    293             p.__name__ if inspect.isfunction(p) else type(p).__name__
    294             for p in self.passes[:i]
    295         ]
    296         msg = f\"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}\"
--> 297         raise Exception(msg) from e  # noqa: TRY002
    299 # If the graph no longer changes, then we can stop running these passes
    300 overall_modified = overall_modified or modified

Exception: An error occurred when running the 'SpecPropPass' pass after the following passes: []"
}

Versions

Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.10.14 (main, Mar 19 2024, 21:46:16) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3

Versions of relevant libraries:
[pip3] executorch==0.3.0a0+ca8e0d2
[pip3] executorchcoreml==0.0.1
[pip3] numpy==1.26.4
[pip3] torch==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.19.0
[conda] Could not collect

cc @kimishpatel @YifanShenSZ @cymbalrush

Metadata

Metadata

Assignees

Labels

module: coreml — Issues related to Apple's Core ML delegation and code under backends/apple/coreml/
need-user-input — The issue needs more information from the reporter before moving forward
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions