Skip to content

Commit

Permalink
optimize topk (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire authored Feb 18, 2024
1 parent c429009 commit 05c5fdc
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 198 deletions.
3 changes: 1 addition & 2 deletions tests/test_converters/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def input(self, batch, num):
def test_gather(self, input, dim, num):
model = _TestModel(num, dim).eval().cuda()
dummy_input = torch.zeros_like(input)
trt_model = module2trt(model,
args=[dummy_input])
trt_model = module2trt(model, args=[dummy_input])

with torch.inference_mode():
gt = model(input)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_converters/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def test_gather(self, input, dim, index):
model = _TestModel(dim)
dummy_input = torch.zeros_like(input)
dummy_index = torch.zeros_like(index)
trt_model = module2trt(model,
args=[dummy_input, dummy_index])
trt_model = module2trt(model, args=[dummy_input, dummy_index])

with torch.inference_mode():
gt = model(input, index)
Expand Down
93 changes: 93 additions & 0 deletions tests/test_converters/test_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import torch
from torch import nn
from torch2trt_dynamic import module2trt


class _TestStaticKModel(nn.Module):

def __init__(self, k, dim, largest) -> None:
super().__init__()
self.k = k
self.dim = dim
self.largest = largest

def forward(self, input):
val, index = input.topk(k=self.k, dim=self.dim, largest=self.largest)
return val, index


class _TestDynamicModel(nn.Module):

def __init__(self, k, dim, largest) -> None:
super().__init__()
self.k = k
self.dim = dim
self.largest = largest

def forward(self, input):
new_k = input.size(self.dim)
k = min(self.k, new_k)
val, index = input.topk(k=k, dim=self.dim, largest=self.largest)
return val, index


class TestTopk:
    """Check converted ``topk`` modules against their eager PyTorch output."""

    # (input shape, topk dim) pairs exercised by both tests.
    _SHAPE_DIM_CASES = [
        ((5, 10), 0),
        ((5, 10), 1),
        ((5, ), 0),
    ]

    @pytest.fixture
    def shape(self, request):
        return request.param

    @pytest.fixture
    def dim(self, request):
        return request.param

    @pytest.fixture
    def k(self, request):
        return request.param

    @pytest.fixture
    def largest(self, request):
        return request.param

    @pytest.fixture
    def input(self, shape):
        # Random values on the GPU with the parametrized shape.
        return torch.rand(shape).cuda()

    @staticmethod
    def _convert_and_compare(model, input):
        """Convert ``model`` with ``module2trt`` and compare both outputs.

        The conversion is driven by a zero tensor shaped like ``input``;
        the real ``input`` is then fed to both the original and the
        converted module.
        """
        example = torch.zeros_like(input)
        converted = module2trt(model, args=[example])

        with torch.inference_mode():
            expected = model(input)
            actual = converted(input)

        expected_values, expected_indices = expected[0], expected[1]
        torch.testing.assert_close(actual[0], expected_values)
        # Cast the converted indices to int64 before comparing them with
        # the indices returned by the eager module.
        torch.testing.assert_close(actual[1].to(torch.int64),
                                   expected_indices)

    @pytest.mark.parametrize('shape,dim', _SHAPE_DIM_CASES)
    @pytest.mark.parametrize('k', [3])
    @pytest.mark.parametrize('largest', [True, False])
    def test_static(self, input, k, dim, largest):
        self._convert_and_compare(_TestStaticKModel(k, dim, largest), input)

    @pytest.mark.parametrize('shape,dim', _SHAPE_DIM_CASES)
    @pytest.mark.parametrize('k', [6])
    @pytest.mark.parametrize('largest', [True, False])
    def test_dynamic(self, input, k, dim, largest):
        self._convert_and_compare(_TestDynamicModel(k, dim, largest), input)
1 change: 1 addition & 0 deletions torch2trt_dynamic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .converters import * # noqa: F401,F403
from .torch2trt_dynamic import * # noqa: F401,F403
from .trt_module import TRTModule, TRTModuleMeta # noqa: F401, F403


def load_plugins():
Expand Down
83 changes: 53 additions & 30 deletions torch2trt_dynamic/converters/topk.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,79 @@
import tensorrt as trt
import torch
from torch2trt_dynamic.module_test import add_module_test
from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
trt_)
from torch2trt_dynamic.torch2trt_dynamic import (bind_arguments,
tensorrt_converter, trt_)

from .size import IntWarper


def _dummy_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
pass


@tensorrt_converter('torch.topk')
@tensorrt_converter('torch.Tensor.topk')
def convert_topk(ctx):
arguments = bind_arguments(_dummy_topk, ctx)
input = arguments['input']
k = arguments['k']
dim = arguments['dim']
largest = arguments['largest']

if dim is None:
dim = len(input.shape) - 1
if dim < 0:
dim = len(input.shape) + dim

def __add_unsqueeze_layer(input_trt, dim):
layer = ctx.network.add_shuffle(input_trt)
layer.reshape_dims = (1, ) + tuple(input_trt.shape)
input_trt = layer.get_output(0)
dim += 1
return input_trt, dim

def __add_topk_layer(k, dim):
topkOp = trt.TopKOperation.MAX if largest else trt.TopKOperation.MIN

k_trt = None
if isinstance(k, IntWarper):
k_trt = trt_(ctx.network, k)
layer = ctx.network.add_shuffle(k_trt)
layer.reshape_dims = tuple()
k_trt = layer.get_output(0)

input = ctx.method_args[0]
if isinstance(k, int) and k > 3840:
print('Clamp k to 3840.')
k = 3840

k = get_arg(ctx, 'k', pos=1, default=1)
axis = get_arg(ctx, 'dim', pos=2, default=len(input.shape) - 1)
if axis is None:
axis = len(input.shape) - 1
if axis < 0:
axis = len(input.shape) + axis
layer = ctx.network.add_topk(input_trt, topkOp, k, 1 << dim)

if k > 3840:
print('warning: topk = ' + k +
' > 3840 is not allowed in TensorRT, use 3840 instead.')
k = 3840
if k_trt is not None:
layer.set_input(1, k_trt)

largest = get_arg(ctx, 'largest', pos=3, default=True)
topkOp = trt.TopKOperation.MAX if largest else trt.TopKOperation.MIN
output0_trt = layer.get_output(0)
output1_trt = layer.get_output(1)
return output0_trt, output1_trt

def __add_squeeze_layer(output_trt):
layer = ctx.network.add_shuffle(output_trt)
layer.reshape_dims = tuple(output_trt.shape)[1:]
return layer.get_output(0)

input_trt = trt_(ctx.network, input)
output = ctx.method_return

# can only use topk on dim>=2
need_unsqueeze = len(input_trt.shape) == 1
if need_unsqueeze:
layer = ctx.network.add_shuffle(input_trt)
layer.reshape_dims = (1, ) + tuple(input_trt.shape)
input_trt = layer.get_output(0)
axis += 1

layer = ctx.network.add_topk(input_trt, topkOp, k, 1 << axis)
input_trt, dim = __add_unsqueeze_layer(input_trt, dim)

output0_trt = layer.get_output(0)
output1_trt = layer.get_output(1)
output0_trt, output1_trt = __add_topk_layer(k, dim)

# recovery
if need_unsqueeze:
layer = ctx.network.add_shuffle(output0_trt)
layer.reshape_dims = tuple(output0_trt.shape)[1:]
output0_trt = layer.get_output(0)

layer = ctx.network.add_shuffle(output1_trt)
layer.reshape_dims = tuple(output1_trt.shape)[1:]
output1_trt = layer.get_output(0)
output0_trt = __add_squeeze_layer(output0_trt)
output1_trt = __add_squeeze_layer(output1_trt)

output[0]._trt = output0_trt
output[1]._trt = output1_trt
Expand Down
Loading

0 comments on commit 05c5fdc

Please sign in to comment.