Skip to content

Commit

Permalink
[Torch] Fix PyTorch NMS conversion for negative scores (apache#7137)
Browse files Browse the repository at this point in the history
* Fix PyTorch NMS conversion for negative scores

* Update Mask R-CNN test to verify outputs and also run on the CUDA target

* set rpn_post_nms_top_n_test to 200

* fix parameter name

* dump output box information

* simplifying
  • Loading branch information
masahi authored and trevor-m committed Jan 21, 2021
1 parent 87673b8 commit 3f2e8e7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 45 deletions.
14 changes: 8 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,16 +1857,18 @@ def nms(self, inputs, input_types):
scores = inputs[1]
iou_threshold = inputs[2]

num_boxes = _op.shape_of(scores)

# TVM NMS assumes score > 0
scores = scores - _op.min(scores) + _op.const(1.0)
# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})

# Prepare input data for get_valid_counts
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# Leverage get_valid_counts to sort the data and clear invalid boxes
ct, data, indices = get_relay_op("get_valid_counts")(
data, score_threshold=-1.0, id_index=-1, score_index=0
)
# PyTorch NMS doesn't have score_threshold, so there is no need to run get_valid_counts
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
ct = num_boxes

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,10 +1675,10 @@ def _gen_rand_inputs(num_boxes):
boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5
boxes[:, 2] += boxes[:, 0]
boxes[:, 3] += boxes[:, 1]
scores = torch.rand(num_boxes, dtype=torch.float)
scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32))
return boxes, scores

targets = ["llvm"] # dynamic nms does not work on gpu
targets = ["llvm", "cuda"]

for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
Expand Down
69 changes: 32 additions & 37 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tvm

import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download
Expand Down Expand Up @@ -70,7 +71,7 @@ def generate_jit_model(index):
]

model_func = model_funcs[index]
model = TraceWrapper(model_func(pretrained=True))
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
Expand All @@ -94,46 +95,40 @@ def test_detection_models():
download(img_url, img)

input_shape = (1, 3, in_size, in_size)
target = "llvm"

input_name = "input0"
shape_list = [(input_name, input_shape)]
score_threshold = 0.9

scripted_model = generate_jit_model(1)
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
vm = VirtualMachine(vm_exec, ctx)
data = process_image(img)
pt_res = scripted_model(data)
data = data.detach().numpy()
vm.set_input("main", **{input_name: data})
tvm_res = vm.run()

# Note: due to accumulated numerical error, we can't directly compare results
# with pytorch output. Some boxes might have a quite tiny difference in score
# and the order can become different. We just measure how many valid boxes
# there are for input image.
pt_scores = pt_res[1].detach().numpy().tolist()
tvm_scores = tvm_res[1].asnumpy().tolist()
num_pt_valid_scores = num_tvm_valid_scores = 0

for score in pt_scores:
if score >= score_threshold:
num_pt_valid_scores += 1
else:
break

for score in tvm_scores:
if score >= score_threshold:
num_tvm_valid_scores += 1
else:
break

assert num_pt_valid_scores == num_tvm_valid_scores, (
"Output mismatch: Under score threshold {}, Pytorch has {} valid "
"boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores)
)
data_np = data.detach().numpy()

with torch.no_grad():
pt_res = scripted_model(data)

for target in ["llvm", "cuda"]:
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)

vm.set_input("main", **{input_name: data_np})
tvm_res = vm.run()

# Bounding boxes
tvm.testing.assert_allclose(
pt_res[0].cpu().numpy(), tvm_res[0].asnumpy(), rtol=1e-5, atol=1e-5
)
# Scores
tvm.testing.assert_allclose(
pt_res[1].cpu().numpy(), tvm_res[1].asnumpy(), rtol=1e-5, atol=1e-5
)
# Class ids
np.testing.assert_equal(pt_res[2].cpu().numpy(), tvm_res[2].asnumpy())

score_threshold = 0.9
print("Num boxes:", pt_res[0].cpu().numpy().shape[0])
print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold))

0 comments on commit 3f2e8e7

Please sign in to comment.