Description
🐞Describing the bug
I am converting a PyTorch layer to Core ML to measure whether it speeds up inference compared to pure PyTorch on MPS.
I traced my layer with torch.jit.trace, and the whole conversion runs quickly until it tries to create a _MLModelProxy instance, where it hangs for 5-6 hours.
The problem occurs on line 144 of models/model.py
I narrowed the problem down to the _MLModelProxy creation by setting breakpoints around each significant function and checking whether the program hangs in that function. I cannot trace any further, as it seems to call into lower-level libraries from that point.
This seems like unintended behaviour: the model itself is successfully converted to the MIL backend without problems, and converting a traced layer with under 25 operations shouldn't take 5 hours at 100% CPU utilization.
In my layer I do use einops.rearrange transformations as well as torch.einsum, but they should just trace to torch.permute and torch.reshape operations and shouldn't cause extreme compile times like this.
It should also be said that after the 5-6 hours the layer does compile to an mlpackage that produces the correct output, so there is no functional bug; it is only the time it takes that seems extreme.
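The breakpoint approach above can be complemented with a timed stack dump, so the hanging Python frame is printed without manual stepping. This is a minimal sketch using the standard-library `faulthandler` module; the helper name `dump_stack_if_stuck` is hypothetical, and it only shows Python frames, not the native code the call disappears into.

```python
import contextlib
import faulthandler
import sys


@contextlib.contextmanager
def dump_stack_if_stuck(timeout_s=60.0, file=sys.stderr):
    """Dump every thread's Python traceback each `timeout_s` seconds
    for as long as the wrapped block keeps running."""
    # `file` must stay open and have a real file descriptor until cancelled.
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)
    try:
        yield
    finally:
        # Stop the watchdog once the block finishes (or raises).
        faulthandler.cancel_dump_traceback_later()


# Usage with the conversion call (assumed, not run here):
# with dump_stack_if_stuck(timeout_s=120):
#     mlmodel = ct.convert(...)
```

If the dumped traceback always ends at the same frame, that confirms where Python hands control to native code; going deeper would need a native profiler or debugger.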
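The expectation that these calls reduce to plain tensor ops can be checked numerically, independent of the converter. The sketch below assumes only PyTorch; the helper names `split_heads`, `merge_heads`, `attn_bmm` and `attn_einsum` are hypothetical, and it says nothing about what the tracer or coremltools actually emits for the original ops.

```python
import torch


def split_heads(t, h):
    # Same layout as rearrange(t, 'b n (h d) -> (b h) n d', h=h)
    b, n, hd = t.shape
    d = hd // h
    return t.reshape(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)


def merge_heads(t, h):
    # Same layout as rearrange(t, '(b h) n d -> b n (h d)', h=h)
    bh, n, d = t.shape
    b = bh // h
    return t.reshape(b, h, n, d).permute(0, 2, 1, 3).reshape(b, n, h * d)


def attn_einsum(q, k, v):
    # The einsum formulation used in the layer
    sim = torch.einsum('b i d, b j d -> b i j', q, k)
    weights = sim.softmax(dim=-1)
    return torch.einsum('b i j, b j d -> b i d', weights, v)


def attn_bmm(q, k, v):
    # The same computation written with batched matmuls
    sim = torch.bmm(q, k.transpose(1, 2))
    weights = sim.softmax(dim=-1)
    return torch.bmm(weights, v)


torch.manual_seed(0)
x = torch.rand(2, 5, 24)                 # (batch, tokens, heads * head_dim)
q = k = v = split_heads(x, h=3)          # (batch * heads, tokens, head_dim)
same_attn = torch.allclose(attn_einsum(q, k, v), attn_bmm(q, k, v), atol=1e-6)
round_trip = torch.equal(merge_heads(split_heads(x, h=3), h=3), x)
```

Swapping the einops and einsum calls in the layer for these equivalents is one way to test whether those ops are related to the slow compile.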
Stack Trace
This is the console output before it hangs.
Saving value type of int64 into a builtin type of int32, might overflow or loses precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 689/690 [00:01<00:00, 604.40 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 14.95 passes/s]
Running MIL default pipeline: 100%|██████████| 57/57 [00:11<00:00, 4.85 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 10/10 [00:00<00:00, 338.07 passes/s]
This process is what is running while the program hangs. The image was taken after a couple of restarts, so the process had "only" run for an hour, though that is still much more than expected.
To Reproduce
This is a simplified version of my layer that includes the same operations. It also hangs when compiled. The einops "allow_ops_in_compiled_graph" call just tells PyTorch to allow the operations in the graph even if they aren't traceable, though they should compile to standard PyTorch operations in the end.
import torch
import torch.nn as nn
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
allow_ops_in_compiled_graph()
import coremltools as ct
import numpy as np
def attn(q, k, v):
    sim = torch.einsum('b i d, b j d -> b i j', q, k)
    attn = sim.softmax(dim=-1)
    out = torch.einsum('b i j, b j d -> b i d', attn, v)
    return out
class ANEVarAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        h = self.num_heads
        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        q *= self.scale
        # split off the CLS token (position 0 along the token dimension, dim 1)
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
        # let CLS token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v)
        cls_out = rearrange(cls_out, '(b h) n d -> b n (h d)', h=h)
        out = self.proj(cls_out)
        return out
device = 'mps'
data = torch.rand((5, 785, 768)).to(device)
ane_layer = ANEVarAttention(768, 12).to(device)
ane_traced_layer = torch.jit.trace(ane_layer, (data,))
np_data = data.cpu().detach().numpy()
ane_mlpackage_obj = ct.convert(
    ane_traced_layer,  # was misspelled as traced_ane_layer, which raises a NameError
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(
            "x",
            shape=np_data.shape,
            dtype=np.float32,
        )
    ],
    debug=True,
    compute_units=ct.ComputeUnit.ALL,
)
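To put numbers on which phase is slow, the conversion and the model load can be timed separately. The sketch below is a generic timer; the helper name `timed` and the `layer.mlpackage` path are hypothetical. The commented usage assumes that `skip_model_load=True` in `ct.convert` defers the compile/load step (and with it the _MLModelProxy creation), which I have not verified against this exact coremltools version.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Store the wall-clock duration of the wrapped block in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


# Usage with the script above (assumed, not run here):
# timings = {}
# with timed("convert_without_load", timings):
#     mlmodel = ct.convert(ane_traced_layer, convert_to="mlprogram",
#                          inputs=[ct.TensorType("x", shape=np_data.shape)],
#                          skip_model_load=True)
# mlmodel.save("layer.mlpackage")
# with timed("load_and_compile", timings):
#     ct.models.MLModel("layer.mlpackage")
# print(timings)
```

If nearly all the time lands in the load step, that matches the observation that the MIL pipelines finish in seconds and the hang starts afterwards.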
System environment (please complete the following information):
- coremltools version: 6.3.0
- OS (e.g. MacOS version or Linux type): Ventura 13.4.1, M1 16GB
- Any other relevant version information (e.g. PyTorch or TensorFlow version):
-- PyTorch: 2.0.1
-- einops: 0.6.1
-- numpy: 1.24.4