Skip to content

Commit

Permalink
update NMS pattern following frontend change
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 26, 2020
1 parent bcdf774 commit 4d43fdc
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 34 deletions.
14 changes: 7 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,26 +1857,26 @@ def nms(self, inputs, input_types):
scores = inputs[1]
iou_threshold = inputs[2]

num_boxes = _op.shape_of(scores)

# TVM NMS assumes score > 0
scores = scores - _op.min(scores) + _op.const(1.0)

num_boxes = _op.shape_of(scores)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
ct = num_boxes

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
score_index = 0
top_k = max_out_size = -1
nms_ret = get_relay_op("non_max_suppression")(
data=data,
valid_count=ct,
valid_count=num_boxes,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_threshold,
Expand Down
48 changes: 26 additions & 22 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
is_op,
rewrite,
is_tuple,
is_tuple_get_item,
wildcard,
DFPatternCallback,
)
Expand All @@ -37,7 +36,7 @@ def is_version_greater_than(ver):
)


def batched_nms_pattern(boxes, scores, idxs, iou_threshold):
def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""A pattern to detect batched_nms function in torchvision
The inputs to this function, boxes, scores, idxs, iou_threshold are wildcard
Expand All @@ -53,7 +52,9 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
return keep
Here is how PyTorch frontend lowers above PyTorch code. For simplicity, Relay ops for
dealing with dynamic strided_slice are omitted.
dealing with dynamic strided_slice are omitted. %num_boxes, %indices are complex
expressions, but since we can use the wildcard part for them, we do not need to construct
their patterns.
%2 = expand_dims(%scores, axis=-1);
%3 = cast(%idxs, dtype="float32");
Expand All @@ -66,11 +67,9 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
%10 = (%2, %9);
%11 = concatenate(%10, axis=-1);
%12 = expand_dims(%11, axis=0);
%13 = vision.get_valid_counts(%12, -1f, meta[relay.attrs.GetValidCountsAttrs][0]);
%14 = %13.1;
%15 = %13.0;
%16 = %13.2;
%17 = vision.non_max_suppression(%14, %15, %16, -1, 0.7f, ...);
...
...
%17 = vision.non_max_suppression(%12, %num_boxes, %indices, -1, 0.7f, ...);
"""
one = is_constant()
Expand Down Expand Up @@ -106,15 +105,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
score_expand_dims = is_op("expand_dims")(scores)
tup = is_tuple([score_expand_dims, add])
concat = is_op("concatenate")(tup)
expand_dims = is_op("expand_dims")(concat)

get_valid_counts_out = is_op("vision.get_valid_counts")(expand_dims, is_constant())
data = is_tuple_get_item(get_valid_counts_out, 1)
valid_counts = is_tuple_get_item(get_valid_counts_out, 0)
indices = is_tuple_get_item(get_valid_counts_out, 2)
data = is_op("expand_dims")(concat)

return is_op("vision.non_max_suppression")(
data, valid_counts, indices, is_constant(), iou_threshold
data, num_boxes, indices, is_constant(), iou_threshold
)


Expand All @@ -128,22 +122,30 @@ def __init__(self):
self.scores = wildcard()
self.idxs = wildcard()
self.iou_threshold = wildcard()
self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold)
self.num_boxes = wildcard()
self.indices = wildcard()

self.pattern = batched_nms_pattern(
self.boxes,
self.scores,
self.idxs,
self.iou_threshold,
self.num_boxes,
self.indices,
)

def convert_batched_nms(self, boxes, scores, idxs, iou_thres):
def convert_batched_nms(self, boxes, scores, idxs, iou_thres, num_boxes, indices):
"""Restore class-aware NMS using extracted class indices"""
scores = op.expand_dims(scores, axis=-1, num_newaxis=1)
idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1)
idxs = op.cast(idxs, "float32")
data = op.concatenate([idxs, scores, boxes], -1)
data = op.expand_dims(data, 0, 1)
ct, data, indices = op.vision.get_valid_counts(
data, score_threshold=-1.0, id_index=0, score_index=1
)

top_k = max_out_size = -1
out = op.vision.non_max_suppression(
data=data,
valid_count=ct,
valid_count=num_boxes,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_thres,
Expand All @@ -162,7 +164,9 @@ def callback(self, pre, post, node_map):
scores = node_map[self.scores][0]
idxs = node_map[self.idxs][0]
iou_thres = node_map[self.iou_threshold][0]
return self.convert_batched_nms(boxes, scores, idxs, iou_thres)
num_boxes = node_map[self.num_boxes][0]
indices = node_map[self.indices][0]
return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)


def rewrite_nms_to_batched_nms(mod):
Expand Down
10 changes: 5 additions & 5 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,17 @@ def test_detection_models():
with torch.no_grad():
pt_res = scripted_model(data)

def compile_and_run_vm(mod, params, data_np):
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
def compile_and_run_vm(mod, params, data_np, target):
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)
vm.set_input("main", **{input_name: data_np})
return vm.run()

for target in ["cuda", "llvm"]:
tvm_res = compile_and_run_vm(mod, params, data_np)
tvm_res = compile_and_run_vm(mod, params, data_np, target)

# Bounding boxes
tvm.testing.assert_allclose(
Expand All @@ -141,7 +141,7 @@ def compile_and_run_vm(mod, params, data_np):
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np)
tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")

# Results should be equivalent after rewriting
for res1, res2 in zip(tvm_res, tvm_res_after_rewrite):
Expand Down

0 comments on commit 4d43fdc

Please sign in to comment.