Skip to content

Commit

Permalink
Merge branch 'main' into jialli/nominmax
Browse files Browse the repository at this point in the history
  • Loading branch information
skyline75489 authored Dec 23, 2024
2 parents c2db68c + 729d1c6 commit e90c53f
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def superresolution(model_file: Path, output_file: Path, output_format: str, onn


def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jpg',
onnx_opset: int = 16, num_classes: int = 80, input_shape: List[int] = None):
onnx_opset: int = 16, num_classes: int = 80, input_shape: List[int] = None,
output_as_image: bool = True):
"""
SSD-like model and Faster-RCNN-like model are including NMS inside already, You can find it from onnx model zoo.
Expand All @@ -185,6 +186,7 @@ def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jp
:param onnx_opset: The opset version of onnx model, default(16).
:param num_classes: The number of classes, default(80).
:param input_shape: The shape of input image (height,width), default will be asked from model input.
:param output_as_image: The flag that means that the model should have the image with boxes instead of the coordinates of the boxess
"""
model = onnx.load(str(model_file.resolve(strict=True)))
inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]
Expand Down Expand Up @@ -284,19 +286,23 @@ def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jp
utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
]),
# DrawBoundingBoxes on the original image
# Model imported from pytorch has CENTER_XYWH format
# two mode for how to color box,
# 1. colour_by_classes=True, (colour_by_classes), 2. colour_by_classes=False,(colour_by_confidence)
(DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=num_classes, colour_by_classes=True),
[
utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
]),
# Encode to jpg/png
ConvertBGRToImage(image_format=output_format),
]

if output_as_image:
post_processing_steps += [
# DrawBoundingBoxes on the original image
# Model imported from pytorch has CENTER_XYWH format
# two mode for how to color box,
# 1. colour_by_classes=True, (colour_by_classes), 2. colour_by_classes=False,(colour_by_confidence)
(DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=num_classes, colour_by_classes=True),
[
utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
]),
# Encode to jpg/png
ConvertBGRToImage(image_format=output_format),
]

pipeline.add_post_processing(post_processing_steps)

new_model = pipeline.run(model)
Expand Down

0 comments on commit e90c53f

Please sign in to comment.