Skip to content

Commit

Permalink
add ONNX demo
Browse files Browse the repository at this point in the history
  • Loading branch information
wondervictor committed Apr 28, 2024
1 parent 0d1dd63 commit bc314a3
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 160 deletions.
152 changes: 0 additions & 152 deletions deploy/image-demo.py

This file was deleted.

128 changes: 128 additions & 0 deletions deploy/onnx_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import json
import argparse
import os.path as osp

import cv2
import numpy as np
import supervision as sv
import onnxruntime as ort
from mmengine.utils import ProgressBar

# Module-level supervision annotators, created once at import time.
# visualize() uses the bounding-box and label annotators to draw detections.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
# NOTE(review): the mask annotator is not referenced anywhere in the visible
# part of this script.
MASK_ANNOTATOR = sv.MaskAnnotator()


def parse_args():
    """Parse the command-line arguments of the YOLO-World ONNX demo.

    Returns:
        argparse.Namespace: namespace with the positional values ``onnx``,
        ``image`` and ``text`` plus the optional ``output_dir`` and
        ``device`` values.
    """
    cli = argparse.ArgumentParser('YOLO-World ONNX Demo')

    # Required positional arguments, registered in command-line order.
    positionals = (
        ('onnx', 'onnx file'),
        ('image', 'image path, include image file or dir.'),
        ('text',
         'detecting texts (str, txt, or json), should be consistent with the ONNX model'
         ),
    )
    for dest, description in positionals:
        cli.add_argument(dest, help=description)

    # Optional flags together with their default values.
    optionals = (
        ('--output-dir', './output', 'directory to save output files'),
        ('--device', 'cuda:0', 'device used for inference'),
    )
    for flag, fallback, description in optionals:
        cli.add_argument(flag, default=fallback, help=description)

    return cli.parse_args()


def preprocess(image, size=(640, 640)):
    """Pad an image to a square, resize it and normalise it for the model.

    The image is centred on a zero-filled square canvas whose side equals
    the longer image edge, resized to ``size`` with bilinear interpolation,
    converted to float32, divided by 255 and given a leading batch axis.

    Args:
        image (np.ndarray): image array indexed as (height, width, 3).
        size (tuple[int, int]): target size handed to ``cv2.resize``.

    Returns:
        tuple: ``(blob, scale_factor, (pad_h, pad_w))`` where ``blob`` is the
        float32 array with a leading batch axis, ``scale_factor`` is
        ``size[0]`` divided by the longer image edge, and ``pad_h`` /
        ``pad_w`` are the top / left padding in original-image pixels.
    """
    height, width = image.shape[:2]
    side = max(height, width)

    # Split the missing rows / columns evenly so the image sits centred.
    top = (side - height) // 2
    left = (side - width) // 2

    canvas = np.zeros((side, side, 3), dtype=image.dtype)
    canvas[top:top + height, left:left + width] = image

    resized = cv2.resize(canvas, size, interpolation=cv2.INTER_LINEAR)
    blob = resized.astype('float32')
    blob /= 255.0

    return blob[None], size[0] / side, (top, left)


def visualize(image, bboxes, labels, scores, texts):
    """Draw detection boxes and ``"<text> <score>"`` captions on an image.

    Args:
        image: image handed to the supervision annotators.
        bboxes: box coordinates passed to ``sv.Detections`` as ``xyxy``.
        labels: class indices passed to ``sv.Detections`` as ``class_id``.
        scores: confidences passed to ``sv.Detections`` as ``confidence``.
        texts: sequence indexed by class id; element ``[0]`` of each entry
            is used as the caption text.

    Returns:
        The image returned by the label annotator after both annotators
        have been applied.
    """
    detections = sv.Detections(xyxy=bboxes, class_id=labels, confidence=scores)

    # One caption per detection: class text followed by a 2-decimal score.
    captions = []
    for cls_idx, conf in zip(detections.class_id, detections.confidence):
        captions.append(f"{texts[cls_idx][0]} {conf:0.2f}")

    boxed = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    return LABEL_ANNOTATOR.annotate(boxed, detections, labels=captions)


def inference(ort_session, image_path, texts, output_dir, size=(640, 640)):
    """Run the ONNX detector on one image and save the annotated result.

    Args:
        ort_session: ONNX Runtime session exposing an ``images`` input and
            ``num_dets`` / ``labels`` / ``scores`` / ``boxes`` outputs.
        image_path (str): path of the image to process.
        texts: class texts indexed by label id, forwarded to ``visualize``.
        output_dir (str): existing directory the annotated image is written
            to, under the same base name as ``image_path``.
        size (tuple[int, int]): network input size forwarded to
            ``preprocess``.

    Returns:
        The annotated image returned by ``visualize``.

    Raises:
        OSError: if ``cv2.imread`` cannot read ``image_path``.
    """
    ori_image = cv2.imread(image_path)
    if ori_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Failed to read image: {image_path}")
    h, w = ori_image.shape[:2]

    # Reverse the channel axis (OpenCV loads BGR) before preprocessing.
    image, scale_factor, pad_param = preprocess(ori_image[:, :, [2, 1, 0]],
                                                size)
    # (1, H, W, C) -> (1, C, H, W)
    input_ort = ort.OrtValue.ortvalue_from_numpy(image.transpose((0, 3, 1, 2)))
    results = ort_session.run(["num_dets", "labels", "scores", "boxes"],
                              {"images": input_ort})
    num_dets, labels, scores, bboxes = results

    # Keep only the valid detections of the single image in the batch.
    num_dets = num_dets[0][0]
    labels = labels[0, :num_dets]
    scores = scores[0, :num_dets]
    bboxes = bboxes[0, :num_dets]

    # Undo preprocess() in reverse order: it pads first and resizes second,
    # so boxes are rescaled to the padded-square frame first and only then
    # shifted by the padding, which is measured in original-image pixels.
    # NOTE(review): assumes boxes are (x1, y1, x2, y2) in network-input
    # pixels -- confirm against the exported model.
    pad_h, pad_w = pad_param
    bboxes /= scale_factor
    bboxes -= np.array([pad_w, pad_h, pad_w, pad_h])

    # x coordinates are bounded by the width, y coordinates by the height.
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, w)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, h)
    bboxes = bboxes.round().astype('int')

    image_out = visualize(ori_image, bboxes, labels, scores, texts)
    cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image_out)
    return image_out


def main():
    """Entry point: load the ONNX model, run it on every image, save results.

    Reads the command line via ``parse_args``, creates the ONNX Runtime
    session, collects the input image(s) and class texts, then calls
    ``inference`` for each image, writing the output into ``--output-dir``.
    """
    args = parse_args()

    # init ONNX session; a CPU device request skips the CUDA provider,
    # anything else keeps CUDA first with CPU as the fallback.
    if args.device.lower().startswith('cpu'):
        providers = ['CPUExecutionProvider']
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    ort_session = ort.InferenceSession(args.onnx, providers=providers)
    print("Init ONNX Runtime session")

    # Honour --output-dir and create intermediate directories as needed.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # load images: a single file, or every PNG/JPEG directly inside a
    # directory (sorted so the processing order is deterministic).
    if osp.isfile(args.image):
        images = [args.image]
    else:
        images = sorted(
            osp.join(args.image, name) for name in os.listdir(args.image)
            if name.lower().endswith(('.png', '.jpg', '.jpeg')))

    # load texts: one class per line (.txt), a JSON document (.json), or a
    # comma-separated string given directly on the command line.
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            texts = [[line.rstrip('\r\n')] for line in f]
    elif args.text.endswith('.json'):
        with open(args.text) as f:
            texts = json.load(f)
    else:
        texts = [[t.strip()] for t in args.text.split(',')]

    print("Start to inference.")
    progress_bar = ProgressBar(len(images))
    for img in images:
        inference(ort_session, img, texts, output_dir=output_dir)
        progress_bar.update()
    print("Finish inference")


if __name__ == "__main__":
    main()
11 changes: 11 additions & 0 deletions docs/deploy.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
## Deploy YOLO-World

- [x] ONNX export
- [x] ONNX demo
- [ ] TensorRT
- [ ] TFLite

We provide several ways to deploy YOLO-World with ONNX or TensorRT.

### Preliminaries
Expand All @@ -24,6 +29,12 @@ You can also use [`export_onnx.py`](../deploy/export_onnx.py) to obtain the ONNX
PYTHONPATH=./ python deploy/export_onnx.py path/to/config path/to/weights --custom-text path/to/customtexts --opset 11
```

**Running ONNX demo**

```bash
python deploy/onnx_demo.py path/to/model.onnx path/to/images path/to/texts
```


### Export YOLO-World to TensorRT models

Expand Down
12 changes: 4 additions & 8 deletions requirements/onnx_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
onnx==1.15.0
onnxruntime==1.17.1
onnx-simplifier==0.4.33
onnx_graphsurgeon
simple_onnx_processing_tools
tensorflow==2.15.0
psutil==5.9.5
ml_dtypes==0.2.0
supervision
onnx
onnxruntime
onnxsim

0 comments on commit bc314a3

Please sign in to comment.