Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic Batch Support for TRT #6955

Merged
merged 34 commits into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move test to tensorrt.py
  • Loading branch information
Ubuntu committed Nov 25, 2020
commit 230c1257c18b32134b63212dd6a06193d6088614
99 changes: 99 additions & 0 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
from tvm.contrib import graph_runtime, utils
from tvm.runtime.vm import VirtualMachine
from tvm.relay import Any, GlobalVar, transform
from typing import Dict, Tuple, Union
from tvm.contrib.download import download
import cv2
from tvm.relay.op.contrib import tensorrt


def skip_codegen_test():
Expand Down Expand Up @@ -1034,5 +1038,100 @@ def set_func_attr(func, compile_name, symbol_name):
tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)


def test_maskrcnn_resnet50() -> None:
codeislife99 marked this conversation as resolved.
Show resolved Hide resolved
"""
This function tests the working of pytorch maskrcnn with resnet50 as backbone with
VM and VM + TRT. Since the order of compiled model outputs is a bit different from
original pytorch model, it uses a custom logic for comparison check.
"""
if skip_codegen_test() or skip_runtime_test():
return

class TraceWrapper(torch.nn.Module):
"""
This class is a wrapper over the torch module to convert the outputs into traceable form
"""

def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model

def forward(
self, inp: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
out = self.model(inp)
return out[0]["boxes"], out[0]["scores"], out[0]["labels"], out[0]["masks"]

def get_traced_maskrcnn_model(np_sample_input: np.ndarray) -> torch.jit.TopLevelTracedModule:
"""
This function takes a sample input and returns the traced maskrcnn model
"""
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=np_sample_input.shape))

with torch.no_grad():
out = model(inp)
traced_module = torch.jit.trace(model, inp)
traced_module.eval()

return traced_module

def get_maskrcnn_input(in_size: int) -> np.ndarray:
"""
This function gets a real image with multiple objects of interest and returns it.
"""
input_shape = (1, 3, in_size, in_size)
img_path = "test_street_small.jpg"
img_url = (
"https://raw.githubusercontent.com/dmlc/web-data/"
"master/gluoncv/detection/street_small.jpg"
)
download(img_url, img_path)

img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)

return img

in_size = 300
np_sample_input = get_maskrcnn_input(in_size)
traced_module = get_traced_maskrcnn_model(np_sample_input)
vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm")
ctx = tvm.cpu()
vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, ctx)
vm.set_input("main", **{"input0": np_sample_input})
tvm_res = vm.run()

# Descending sort by scores and get the high confidence indices. In this example 9 is chosen,
# because this image has 9 boxes over 0.9 confidence
num_high_confidence_boxes = 9
tvm_indices = np.argsort(-1 * tvm_res[1].asnumpy())[:num_high_confidence_boxes]

with torch.no_grad():
out = traced_module(torch.Tensor(np_sample_input))
# Descending sort by scores and get the high confidence indices
pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes]

tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol]
# Because of certain ops, there are certain minor differences in TVM outputs and PT outputs,
# This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around
# this is to test it on an entire dataset and compare mAP with the original model.
# However, since that is not practically possible on CI, the following compromise is made.
# These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g:
# 0.1 pixel difference of a box in a 300X300 image wont make any change.
for i, tol_val in zip(range(4), tol):
np.testing.assert_allclose(
tvm_res[i].asnumpy()[tvm_indices],
out[i].numpy()[pt_indices],
rtol=tol_val,
atol=tol_val,
)


if __name__ == "__main__":
pytest.main([__file__])
94 changes: 0 additions & 94 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3384,99 +3384,6 @@ def convert_traced_model_to_vm_trt(
return vm_trt_exec


def test_maskrcnn_resnet50() -> None:
"""
This function tests the working of pytorch maskrcnn with resnet50 as backbone with
VM and VM + TRT. Since the order of compiled model outputs is a bit different from
original pytorch model, it uses a custom logic for comparison check.
"""

class TraceWrapper(torch.nn.Module):
"""
This class is a wrapper over the torch module to convert the outputs into traceable form
"""

def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model

def forward(
self, inp: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
out = self.model(inp)
return out[0]["boxes"], out[0]["scores"], out[0]["labels"], out[0]["masks"]

def get_traced_maskrcnn_model(np_sample_input: np.ndarray) -> torch.jit.TopLevelTracedModule:
"""
This function takes a sample input and returns the traced maskrcnn model
"""
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=np_sample_input.shape))

with torch.no_grad():
out = model(inp)
traced_module = torch.jit.trace(model, inp)
traced_module.eval()

return traced_module

def get_maskrcnn_input(in_size: int) -> np.ndarray:
"""
This function gets a real image with multiple objects of interest and returns it.
"""
input_shape = (1, 3, in_size, in_size)
img_path = "test_street_small.jpg"
img_url = (
"https://raw.githubusercontent.com/dmlc/web-data/"
"master/gluoncv/detection/street_small.jpg"
)
download(img_url, img_path)

img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)

return img

in_size = 300
np_sample_input = get_maskrcnn_input(in_size)
traced_module = get_traced_maskrcnn_model(np_sample_input)
vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm")
ctx = tvm.cpu()
vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, ctx)
vm.set_input("main", **{"input0": np_sample_input})
tvm_res = vm.run()

# Descending sort by scores and get the high confidence indices. In this example 9 is chosen,
# because this image has 9 boxes over 0.9 confidence
num_high_confidence_boxes = 9
tvm_indices = np.argsort(-1 * tvm_res[1].asnumpy())[:num_high_confidence_boxes]

with torch.no_grad():
out = traced_module(torch.Tensor(np_sample_input))
# Descending sort by scores and get the high confidence indices
pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes]

tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol]
# Because of certain ops, there are certain minor differences in TVM outputs and PT outputs,
# This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around
# this is to test it on an entire dataset and compare mAP with the original model.
# However, since that is not practically possible on CI, the following compromise is made.
# These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g:
# 0.1 pixel difference of a box in a 300X300 image wont make any change.
for i, tol_val in zip(range(4), tol):
np.testing.assert_allclose(
tvm_res[i].asnumpy()[tvm_indices],
out[i].numpy()[pt_indices],
rtol=tol_val,
atol=tol_val,
)


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -3621,7 +3528,6 @@ def get_maskrcnn_input(in_size: int) -> np.ndarray:

test_segmentaton_models()
test_3d_models()
test_maskrcnn_resnet50()

# Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules
Expand Down