Description
🐞Describe the bug
Attempting to convert the PyTorch-based d2go Mask R-CNN model with the FBNetV2C4 backbone (https://github.com/facebookresearch/d2go/blob/main/configs/mask_rcnn_fbnetv3a_C4.yaml) to Core ML.
Conversion to MIL is successful, after a number of torch ops were added and some were given minor patches (I'll list these at the end):
Converting Frontend ==> MIL Ops: 100%|█████████▉| 1027/1029 [00:04<00:00, 217.11 ops/s]
However, the model runs into an unknown error when compiling: "compiler error: Unknown error in building network shapes."
I've noticed that if I convert the model with convert_to="mlprogram", the error changes to "compiler error: Encountered an error while compiling a neural network model: in operation of type cond: cond must return same types from both branches".
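Concretely, the two conversion calls (traced_model and inputs as in the script under "To Reproduce" below):

```python
# Default (neuralnetwork) backend:
#   "compiler error: Unknown error in building network shapes."
ct.convert(traced_model, inputs=inputs)

# mlprogram backend:
#   "compiler error: ... in operation of type cond: cond must return
#    same types from both branches"
ct.convert(traced_model, inputs=inputs, convert_to="mlprogram")
```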
Stack Trace
Translating MIL ==> NeuralNetwork Ops: 100%|██████████| 23/23 [00:00<00:00, 12615.27 ops/s]
Translating MIL ==> NeuralNetwork Ops: 100%|██████████| 1030/1030 [00:13<00:00, 76.75 ops/s]
/conversion/lib/python3.9/site-packages/coremltools/models/model.py:137: RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: Error compiling model: "compiler error: Unknown error in building network shapes.".
but, when converting with convert_to="mlprogram", changes to
RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: Error compiling model: "compiler error: Encountered an error while compiling a neural network model: in operation of type cond: cond must return same types from both branches".
To Reproduce
I've not been able to isolate the layer causing this issue in the translation of MIL to NN, but I would be very happy to dig deeper and try to find a minimal example. If there are any suggestions on how I might do this, that would be really appreciated (one idea I've been considering is sketched after the script below).
I'm running a basic conversion script:
```python
import coremltools as ct
import torch
import torchvision  # noqa: F401 (imported for its torchscript op registrations)
from d2go.model_zoo import model_zoo
from detectron2.export import TracingAdapter

import maskrcnn.ops  # noqa: F401 (registers the added/patched torch ops listed below)


def inference_func(tmodel, image):
    inputs = [{"image": image}]
    return tmodel.inference(inputs, do_postprocess=False)[0]


def main():
    model = model_zoo.get('mask_rcnn_fbnetv3a_C4.yaml', trained=True)
    example = torch.rand(3, 224, 224)
    wrapper = TracingAdapter(model, example, inference_func)
    wrapper.eval()
    traced_model = torch.jit.trace(wrapper, (example,))

    inputs = [ct.ImageType(name="my_input", shape=(3, 224, 224))]
    coreml_model = ct.convert(
        traced_model,
        inputs=inputs,
    )


if __name__ == '__main__':
    main()
```
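One idea I've been considering for narrowing this down: stop the conversion at the internal MIL representation and inspect the cond ops directly, since the mlprogram error points at a cond whose branches return different types. A rough sketch (assuming convert_to="milinternal" gets through, as the MIL progress bar above suggests; traced_model and inputs as in the script):

```python
# Stop at the internal MIL program and list the cond ops together with
# the symbolic types of their outputs, to spot a branch-type mismatch.
# (Top-level ops only; nested blocks would need a recursive walk.)
mil_prog = ct.convert(traced_model, inputs=inputs, convert_to="milinternal")

for op in mil_prog.functions["main"].operations:
    if op.op_type == "cond":
        print(op.name, [out.sym_type for out in op.outputs])
```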
System environment (please complete the following information):
- python dependencies
  - python=3.9
  - pytorch=1.9.1
  - torchvision=0.10.1
  - coremltools=5.2.0
- macOS 12.3.1
Additional context
Added torch ops (a sketch of the registration pattern follows this list):
- nms
- repeat_interleave
- numel
- roi_align
- logicaland
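These were added via coremltools' @register_torch_op hook. As a minimal sketch of the pattern, here is numel, decomposed as shape + reduce_prod (my actual implementations may differ in detail):

```python
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op


@register_torch_op
def numel(context, node):
    # numel(x) == product of the entries of x's shape
    x = _get_inputs(context, node, expected=1)[0]
    shape = mb.shape(x=x)
    n = mb.reduce_prod(x=shape, axes=[0], name=node.name)
    context.add(n)
```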
Ops with minor patches (one example patch is sketched after this list):
- clamp
- narrow
- index: handle broadcasting indices
- max: handle inputs where len == 1
- split: handle the case when num_splits == 1
- to: handle the case where the dtype is not set; it should be inferred from the tensor's dtype (see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html?highlight=#torch.Tensor.to)
- tupleunpack: handle the case when len(output) == 1
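The patches override the stock translations in place via register_torch_op(override=True). As a sketch of the style, the tupleunpack case (an approximation of my change, not the exact diff):

```python
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op


@register_torch_op(override=True)
def tupleunpack(context, node):
    values = _get_inputs(context, node, expected=1)[0]
    # With a single output, the "tuple" arrives as a bare value rather
    # than a list, so wrap it before unpacking.
    if len(node.outputs) == 1 and not isinstance(values, (list, tuple)):
        values = [values]
    for value, output in zip(values, node.outputs):
        context.add(value, torch_name=output)
```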