yolov5 to onnx model with nms #159

Closed
trungpham2606 opened this issue Sep 5, 2021 · 23 comments · Fixed by #193
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@trungpham2606

Hello @zhiqwang
Have you finished exporting YOLOv5 with NMS to an ONNX model yet?
I don't see any PR in ultralytics/yolov5.

trungpham2606 added the enhancement (New feature or request) label Sep 5, 2021
@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

Hi @trungpham2606

Actually, in the notebook https://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/export-onnx-inference-onnxruntime.ipynb, the exported ONNX model already contains the NMS operator. The key difference is the implementation of the post-processing linked below (it performs the same task as non_max_suppression, except for the format of the input). If you move this post-processing into ultralytics/yolov5, the exported ONNX model should contain the NMS operator. I'll try to write a more detailed example for this later.

https://github.com/zhiqwang/yolov5-rt-stack/blob/cc2bd50978b7118ae1cb16918248d991d0b927e8/yolort/models/box_head.py#L124-L197

@trungpham2606
Author

trungpham2606 commented Sep 5, 2021

@zhiqwang
Actually, I tried your code to export the ONNX model, but I always get an error about a mismatched shape.
This is the script I used:

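from typing import Dict, List

import torch
import torchvision

# xywh2xyxy_torch is not shown here; it is assumed to convert
# (cx, cy, w, h) boxes to (x1, y1, x2, y2).


class Warp_nms(torch.nn.Module):
    def __init__(self, score_thresh, nms_thresh, detection_per_img):
        super().__init__()
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detection_per_img = detection_per_img

    def forward(self, dump_rois):
        detections: List[Dict[str, torch.Tensor]] = []
        # Keep candidates whose objectness exceeds the threshold
        xc = dump_rois[:, 4] > self.score_thresh
        x = dump_rois[xc]
        # Combined confidence = objectness * class score
        x[:, 5:] *= x[:, 4:5]

        box = xywh2xyxy_torch(x[:, :4])
        # Best class per box, then filter by the combined confidence
        conf, j = x[:, 5:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > self.score_thresh]

        # Sort by confidence and cap the number of candidates
        _, index = x[:, 4].sort(descending=True)
        x = x[index][:30000]
        # Batched NMS: offset boxes by class so classes do not suppress each other
        c = x[:, 5:6] * 4096
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, self.nms_thresh)  # NMS
        detections.append({'dets': x[i[:self.detection_per_img]]})
        return detections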

and the error I got is:
[image: error screenshot]

@trungpham2606
Author

@zhiqwang
The shape (92, 1) is the shape of x[i[:self.detection_per_img]]. I think the shape is being fixed as static. I can't get rid of it, even though I set the output shape to dynamic as you did.

@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

Hi @trungpham2606

BTW, the difference between nms and batched_nms in post-processing is not significant. I think you can follow my processing here first, and later I will try to see if I can call this part of the code directly.
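
To illustrate why the two are close: the `* 4096` offset in the script above shifts each class into its own coordinate region, so plain NMS never compares boxes of different classes. The sketch below is a pure-Python illustration with hypothetical helper names (it does not use torchvision), and it assumes box coordinates stay below the offset.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two xyxy boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: Sequence[Box], scores: Sequence[float], iou_thresh: float) -> List[int]:
    """Greedy class-agnostic NMS; returns kept indices, highest score first."""
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep: List[int] = []
    for k in order:
        # Keep a box only if it does not overlap too much with any kept box.
        if all(iou(boxes[k], boxes[m]) <= iou_thresh for m in keep):
            keep.append(k)
    return keep


def nms_with_class_offset(boxes, scores, labels, iou_thresh, offset=4096.0):
    """Class-aware NMS: shift every box by label * offset, then run plain NMS."""
    shifted = [tuple(v + label * offset for v in box) for box, label in zip(boxes, labels)]
    return nms(shifted, scores, iou_thresh)


def nms_per_class(boxes, scores, labels, iou_thresh):
    """Reference: run NMS separately for each class, then merge by score."""
    keep = []
    for cls in set(labels):
        idx = [k for k, label in enumerate(labels) if label == cls]
        kept = nms([boxes[k] for k in idx], [scores[k] for k in idx], iou_thresh)
        keep.extend(idx[k] for k in kept)
    return sorted(keep, key=lambda k: scores[k], reverse=True)


boxes = [(10, 10, 50, 50), (12, 12, 52, 52), (11, 11, 51, 51), (200, 200, 240, 240)]
scores = [0.9, 0.8, 0.7, 0.6]
labels = [0, 0, 1, 1]

print(nms_with_class_offset(boxes, scores, labels, 0.5))  # class-aware via offset
print(nms_per_class(boxes, scores, labels, 0.5))          # class-aware reference
print(nms(boxes, scores, 0.5))                            # class-agnostic
```

In this toy input, box 2 overlaps box 0 but belongs to another class, so both class-aware variants keep it while class-agnostic NMS drops it.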

@trungpham2606
Author

@zhiqwang
Oh, I will try batched_nms instead and report the result to you later.
Thank you!

@trungpham2606
Author

trungpham2606 commented Sep 5, 2021

Dear @zhiqwang
I tried batched_nms as below:

class Warp_nms(torch.nn.Module):
    def __init__(self, score_thresh, nms_thresh, detection_per_img):
        super().__init__()
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detection_per_img = detection_per_img

    def forward(self, dump_rois):
        detections: List[Dict[str, torch.Tensor]] = []
        xc = dump_rois[:, 4] > self.score_thresh
        x = dump_rois[xc]
        x[:, 5:] *= x[:, 4:5]

        box = xywh2xyxy_torch(x[:, :4])
        conf, j = x[:, 5:].max(1, keepdim=True)

        #for batched_nms
        boxes = box[conf.view(-1) > self.score_thresh]
        classes = j.float()[conf.view(-1) > self.score_thresh]
        scores = conf[[conf.view(-1) > self.score_thresh]]

        _, index = scores.sort(descending=True)

        boxes = boxes[index][:30000].view(-1, 4)
        classes = classes[index][:30000].view(-1)
        scores = scores[index][:30000].view(-1)

        # Batched NMS
        i = torchvision.ops.batched_nms(boxes, scores, classes, self.nms_thresh)  # NMS
        print('here is nms')
        # output = x[i[:300]]
        detections.append({'dets': x[i[:self.detection_per_img]]})
        return detections

During export, I got 2 warnings:
1. UserWarning: Exporting aten::index operator with indices of type Byte. Only 1-D indices are supported. In any other case, this will produce an incorrect ONNX graph.
warnings.warn("Exporting aten::index operator with indices of type Byte. "
2. UserWarning: This model contains a squeeze operation on dimension 1 on an input with unknown shape. Note that if the size of dimension 1 of the input is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on non-singleton dimensions, it is recommended to export this model using opset version 11 or higher.
"version 11 or higher.")

For the second warning, I am already using opset 12.

The export succeeds, but the error is still the same as in my previous attempt.
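
Two details in the script above may be worth checking, independently of the export settings. First, `conf` is created with `keepdim=True`, so `scores` is a column of shape (M, 1); `Tensor.sort` sorts along the last dimension by default, which for a column means every returned index is 0. Second, the indices returned by NMS refer to the filtered `boxes`, while `x` still holds the unfiltered rows. The sketch below mimics both effects with plain Python lists; all names are illustrative and it is not a fix for the exporter itself.

```python
from typing import List


def argsort_desc(values: List[float]) -> List[int]:
    """Indices that sort a flat list in descending order."""
    return sorted(range(len(values)), key=lambda k: values[k], reverse=True)


# A score "column" of shape (3, 1), like the result of max(1, keepdim=True).
column = [[0.3], [0.9], [0.5]]

# Sorting along the last axis sorts each one-element row on its own,
# so every index is 0 and the ordering information is lost.
per_row_index = [argsort_desc(row) for row in column]

# Flattening first gives the ordering that was intended.
flat = [row[0] for row in column]
flat_index = argsort_desc(flat)

# Index alignment: rows before filtering vs. candidates after filtering.
rows = ["r0", "r1", "r2", "r3"]
row_scores = [0.1, 0.9, 0.2, 0.8]
kept_rows = [k for k, s in enumerate(row_scores) if s > 0.5]
candidates = [rows[k] for k in kept_rows]

nms_keep = [1]  # pretend NMS kept candidate number 1
wrong = [rows[k] for k in nms_keep]        # indexes the unfiltered rows
right = [candidates[k] for k in nms_keep]  # indexes what NMS actually saw

print(per_row_index, flat_index)
print(wrong, right)
```

In tensor terms, this suggests flattening the scores before sorting and indexing the same filtered, sorted tensor that was passed to NMS.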

@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

Hi @trungpham2606

You can also check this notebook as a reference: https://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/how-to-align-with-ultralytics-yolov5.ipynb

Contributions that combine ultralytics/yolov5's backbone with yolort's post-processing to solve this ONNX NMS export problem are welcome here.

@trungpham2606
Author

trungpham2606 commented Sep 5, 2021

@zhiqwang
Can I ask what the numbers here are for?
[image]
Is that for a list with 3 elements inside?

@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

Is that for a list with 3 elements inside?

They represent the three dynamic outputs of scores, labels and boxes respectively.

@trungpham2606
Author

Hello @zhiqwang
I think the problem was with batched_nms. I was using batched_nms from torchvision.ops, not from box_ops as you do.
Now I can successfully export to the ONNX model and load the exported model with onnx-simple.
I'll run it to check the performance and tell you the results later.
Thank you so much for your support!

@trungpham2606
Author

@zhiqwang
When the detections have a different shape, it raises a warning like this:
2021-09-05 13:35:40.7281422 [W:onnxruntime:, execution_frame.cc:721 onnxruntime::ExecutionFrame::VerifyOutputSizes] Expected shape from model of {6,4} does not match actual shape of {1,4} for output boxes

I know it's just a notification, but do you know how to get rid of this warning?

@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

Expected shape from model of {6,4} does not match actual shape of {1,4} for output boxes

Hi @trungpham2606, I guess you can use dynamic shapes with something like the following (the dynamic_axes parameter is important here):

torch.onnx.export(
    model,
    (images,),
    export_onnx_name,
    do_constant_folding=True,
    opset_version=_onnx_opset_version,
    dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
    input_names=["images_tensors"],
    output_names=["scores", "labels", "boxes"],
)

I was using batched_nms from torchvision.ops, not from box_ops as you do. Now I can successfully export to the ONNX model and load the exported model with onnx-simple.

BTW, the box_ops here comes directly from torchvision.ops:

https://github.com/zhiqwang/yolov5-rt-stack/blob/cc2bd50978b7118ae1cb16918248d991d0b927e8/yolort/models/box_head.py#L6
and
https://github.com/zhiqwang/yolov5-rt-stack/blob/cc2bd50978b7118ae1cb16918248d991d0b927e8/yolort/models/box_head.py#L190

@trungpham2606
Author

trungpham2606 commented Sep 5, 2021

@zhiqwang
I used the same export code as yours.
But when I changed it to this:

torch.onnx.export(
    model,
    hm,
    f=ONXX_FILE_PATH,
    input_names=['image1'],
    output_names=['scores', 'classes', 'boxes', 'mask_features', 'kpts_features'],
    verbose=False,
    opset_version=11,
    do_constant_folding=True,
    # dynamic_axes= {"outputs": [0, 1, 2, 3, 4]},
    dynamic_axes= {
        'scores': {0: 'sequence'},
        'classes': {0: 'sequence'},
        'boxes': {0: 'sequence'},
        'mask_features': {0: 'sequence'},
        'kpts_features': {0: 'sequence'},
    },
)

The warnings disappeared then.
Anyway, thank you for helping me ^^
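
The difference between the two export calls above appears to come down to the keys of `dynamic_axes`: `torch.onnx.export` documents them as input or output names, and `"outputs"` is not one of the names listed in `output_names`, whereas the per-output dictionary uses exactly those names. The helpers below are hypothetical conveniences (not part of PyTorch or yolort) that build such a mapping and report keys that name no tensor. Treat this as a sketch, not a substitute for reading the exporter's own warnings.

```python
from typing import Dict, Iterable, List


def make_dynamic_axes(names: Iterable[str], axis: int = 0,
                      label: str = "sequence") -> Dict[str, Dict[int, str]]:
    """Mark one axis of every named tensor as dynamic."""
    return {name: {axis: label} for name in names}


def unknown_keys(dynamic_axes: Dict[str, object], input_names: Iterable[str],
                 output_names: Iterable[str]) -> List[str]:
    """Keys of dynamic_axes that match neither an input nor an output name."""
    known = set(input_names) | set(output_names)
    return [key for key in dynamic_axes if key not in known]


input_names = ["image1"]
output_names = ["scores", "classes", "boxes", "mask_features", "kpts_features"]

# Per-output mapping, equivalent to the dictionary written out by hand above.
dynamic_axes = make_dynamic_axes(output_names)
print(dynamic_axes["boxes"])
print(unknown_keys(dynamic_axes, input_names, output_names))

# The commented-out variant uses a key that is not an output name.
print(unknown_keys({"outputs": [0, 1, 2, 3, 4]}, input_names, output_names))
```

The resulting dictionary can be passed as `dynamic_axes=` in the export call shown above.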

@zhiqwang
Owner

zhiqwang commented Sep 5, 2021

The warnings disappeared then.

Hi @trungpham2606, congratulations!

zhiqwang added the help wanted (Extra attention is needed) label Sep 5, 2021
@trungpham2606
Author

trungpham2606 commented Sep 6, 2021

Hi @zhiqwang
I have a question about the performance of the ONNX YOLOv5 model.
Have you compared your yolov5 against the ONNX yolov5 yet?

I noticed that the results from CPUExecutionProvider and CUDAExecutionProvider are different, and the results from CPU execution are much more stable than the CUDA ones.

@zhiqwang
Owner

zhiqwang commented Sep 6, 2021

It seems that more information is needed to determine the cause of this problem. To keep this thread clean, I think it's better to open a new discussion about it.

EDIT: The newly uploaded discussion is #160.

@zhiqwang
Owner

zhiqwang commented Oct 8, 2021

FYI, the following snippet exports a dynamic batch/shape ONNX model containing both the YOLOv5 model and the post-processing (NMS).

# 'yolov5s.pt' is downloaded from https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt
python tools/export_model.py --checkpoint_path yolov5s.pt --skip_preprocess

Check out the details in #193.

I believe this resolves the problem, so I'm closing this issue. Feel free to create another ticket if you have more questions.

zhiqwang closed this as completed Oct 8, 2021
zhiqwang linked a pull request Oct 8, 2021 that will close this issue
@Deronjey

Deronjey commented Apr 7, 2022

Hi @zhiqwang
I have some questions:
1. Regarding "move this post-processing to ultralytics/yolov5": which tag of ultralytics/yolov5 did you use, and where should box_head.py be moved to?
2. How do I use this project to convert my model.pt to model.onnx?
I sincerely hope to get your reply. Thank you.

@zhiqwang
Owner

zhiqwang commented Apr 7, 2022

which tag of ultralytics/yolov5 did you use?

@Deronjey, we support versions 3.1, 4.0 and 6.0 released by ultralytics/yolov5.

Actually, the version 5.0 models released by yolov5 are the same as 4.0, so you can just set upstream_version="r4.0" if you're using 5.0. Likewise, 6.1 is the same as 6.0, so you can set upstream_version to "r6.0" if your model was trained with 6.1.
https://github.com/zhiqwang/yolov5-rt-stack/blob/a6a08dd4b6ac5c24bf0275cbc1701ff561274ae0/yolort/models/__init__.py#L32
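
The version rule described above can be written down as a small lookup. This is an illustrative sketch only: the table and function names are hypothetical (they are not part of yolort), and the `"r3.1"` spelling is assumed by analogy with `"r4.0"` and `"r6.0"`, which are the only values quoted in this thread.

```python
# Maps the ultralytics/yolov5 release used for training to the
# upstream_version string mentioned in this thread.
UPSTREAM_VERSION = {
    "3.1": "r3.1",  # assumed spelling, by analogy with r4.0 / r6.0
    "4.0": "r4.0",
    "5.0": "r4.0",  # 5.0 models are the same as 4.0
    "6.0": "r6.0",
    "6.1": "r6.0",  # 6.1 models are the same as 6.0
}


def upstream_version_for(release: str) -> str:
    """Return the upstream_version for a release tag such as 'v5.0' or '6.1'."""
    key = release.lstrip("v")
    if key not in UPSTREAM_VERSION:
        raise ValueError(f"unsupported ultralytics/yolov5 release: {release!r}")
    return UPSTREAM_VERSION[key]


print(upstream_version_for("v5.0"))
print(upstream_version_for("6.1"))
```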

@Deronjey

Deronjey commented Apr 7, 2022

I used the following command to export the ONNX model, and I used the 5.0 tag of ultralytics/yolov5 to train model.pt. It raises an AttributeError: conv object has no attribute weight. How can I fix this error?

python tools/export_model.py --checkpoint_path model.pt --size_divisible 32

@zhiqwang
Owner

zhiqwang commented Apr 7, 2022

How can I fix this error? This has perplexed me for a long time.

Hi @Deronjey, you can add the argument --version r4.0:

python3 tools/export_model.py --checkpoint_path model.pt --size_divisible 32 --version r4.0

@Deronjey

Deronjey commented Apr 7, 2022

[image: export_model]

It's done. Thank you three thousand times!
