Description
I am trying to convert the SSD300_VGG16 object detection model to a serialized, optimized format so that it can eventually run on an embedded system. The script I use to trace the model with TorchScript and save it is the following:
import torch
import torchvision


def do_trace(model, in_size=500):
    model_trace = torch.jit.trace(model, torch.rand(1, 3, in_size, in_size))
    model_trace.eval()
    return model_trace


def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return (out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"])
    return (out_dict["boxes"], out_dict["scores"], out_dict["labels"])


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])


model_funcs = [torchvision.models.detection.ssd300_vgg16]
names = ["ssd300_vgg16"]

for name, model_func in zip(names, model_funcs):
    model = TraceWrapper(model_func(num_classes=50, pretrained_backbone=False))
    model.eval()
    in_size = 500
    inp = torch.rand(1, 3, in_size, in_size)
    with torch.no_grad():
        out = model(inp)
        script_module = do_trace(model)
        script_out = script_module(inp)
        assert len(out[0]) > 0 and len(script_out[0]) > 0
        torch._C._jit_pass_inline(script_module.graph)
        torch.jit.save(script_module, name + ".pt")
After that, I try to import the saved model through the TVM Relay PyTorch frontend, but I get the error below:
  File "/home/achalhoub/tvm/python/tvm/relay/frontend/pytorch.py", line 3091, in report_missing_conversion
    raise NotImplementedError(msg)
NotImplementedError: The following operators are not implemented: ['aten::clamp_min', 'aten::copy_']
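From what I have read, it seems that the TVM PyTorch frontend does not implement converters for every operator, so some model architectures cannot be imported as traced. Is that what is happening here with aten::clamp_min and aten::copy_?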
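To narrow down where such operators come from, something like the following minimal sketch could list the aten:: operators recorded in a traced graph. It only uses PyTorch and does not touch TVM. TinyModule and collect_op_kinds are hypothetical names made up for illustration; they are not part of my script or of any library. I am assuming that clamp_min and in-place slice assignment are typical sources of the two operators in the error message.

```python
import torch


class TinyModule(torch.nn.Module):
    """Hypothetical stand-in for the detection model, kept small on purpose."""

    def forward(self, x):
        y = x.clamp_min(0.0)       # recorded by the tracer as aten::clamp_min
        out = torch.zeros_like(y)
        out[:, 0] = y[:, 1]        # in-place slice assignment, typically recorded as aten::copy_
        return out


def collect_op_kinds(graph):
    """Return the set of node kinds in a TorchScript graph, including nested blocks."""
    kinds = set()

    def visit(block):
        for node in block.nodes():
            kinds.add(node.kind())
            # Control-flow nodes (prim::If, prim::Loop) carry nested blocks.
            for inner in node.blocks():
                visit(inner)

    visit(graph)
    return kinds


traced = torch.jit.trace(TinyModule().eval(), torch.rand(2, 3))
torch._C._jit_pass_inline(traced.graph)  # same inlining pass as in the script above

# Keep only ATen operators, which are the ones a frontend has to convert.
ops = sorted(k for k in collect_op_kinds(traced.graph) if k.startswith("aten::"))
print(ops)
```

Running the same collect_op_kinds helper on script_module.graph from my script should show every operator the frontend has to handle, which can then be compared against the list in the error.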
I would appreciate help with this issue!
Thanks.
I am using the following versions:
torch 1.10.1+cu102
torchvision 0.11.2+cu102
tvm 0.9.dev0