update gather embedding
grimoire committed Feb 8, 2024
1 parent 697fed8 commit 95d8e6d
Showing 10 changed files with 122 additions and 97 deletions.
44 changes: 44 additions & 0 deletions tests/test_converters/test_embedding.py
@@ -0,0 +1,44 @@
import pytest
import torch
from torch import nn
from torch2trt_dynamic import module2trt


class _TestModel(nn.Module):

def __init__(self, num, dim) -> None:
super().__init__()
        self.embedding = nn.Embedding(num, dim)

def forward(self, input):
        return self.embedding(input)


class TestEmbedding:

@pytest.fixture
def dim(self):
yield 4

@pytest.fixture
def num(self):
yield 10

@pytest.fixture
def batch(self):
yield 2

@pytest.fixture
def input(self, batch, num):
yield torch.randint(num, (batch, 6)).cuda()

    def test_embedding(self, input, dim, num):
model = _TestModel(num, dim).eval().cuda()
dummy_input = torch.zeros_like(input)
trt_model = module2trt(model,
args=[dummy_input])

with torch.inference_mode():
gt = model(input)
out = trt_model(input)
torch.testing.assert_close(out, gt)
43 changes: 43 additions & 0 deletions tests/test_converters/test_gather.py
@@ -0,0 +1,43 @@
import pytest
import torch
from torch import nn
from torch2trt_dynamic import module2trt


class _TestModel(nn.Module):

def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim

def forward(self, input, index):
return torch.gather(input, self.dim, index)


class TestGather:

@pytest.fixture
def input(self):
yield torch.rand(3, 4, 5).cuda()

@pytest.fixture
def dim(self, request):
yield request.param

@pytest.fixture
def index(self, input, dim):
max_val = input.size(dim)
yield torch.randint(max_val, (3, 4, 5)).cuda()

@pytest.mark.parametrize('dim', [0, 1, 2])
def test_gather(self, input, dim, index):
model = _TestModel(dim)
dummy_input = torch.zeros_like(input)
dummy_index = torch.zeros_like(index)
trt_model = module2trt(model,
args=[dummy_input, dummy_index])

with torch.inference_mode():
gt = model(input, index)
out = trt_model(input, index)
torch.testing.assert_close(out, gt)
33 changes: 18 additions & 15 deletions torch2trt_dynamic/converters/Embedding.py
@@ -1,7 +1,20 @@
from ..plugins import create_torchembedding_plugin
import tensorrt as trt

from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_


def _update_weight(weight, max_norm, norm_type):
if max_norm is None:
return weight
num_embeddings = weight.shape[0]
for emb_id in range(num_embeddings):
norm = weight[emb_id].norm(norm_type)
if norm > max_norm:
scale = max_norm / (norm + 1e-7)
weight[emb_id] = weight[emb_id] * scale
return weight


@tensorrt_converter('torch.nn.Embedding.forward')
def convert_embedding_forward(ctx):
module = ctx.method_args[0]
@@ -22,24 +35,14 @@ def convert_embedding(ctx):
norm_type = get_arg(ctx, 'norm_type', pos=4, default=2)
output = ctx.method_return

if max_norm is not None:
num_embeddings = weight.shape[0]
for emb_id in range(num_embeddings):
norm = weight[emb_id].norm(norm_type)
if norm > max_norm:
scale = max_norm / (norm + 1e-7)
weight[emb_id] = weight[emb_id] * scale

weight = _update_weight(weight, max_norm, norm_type)
if padding_idx is not None:
weight[padding_idx, :] = 0

input_trt = trt_(ctx.network, input)
weight_trt = trt_(ctx.network, weight)

plugin = create_torchembedding_plugin(
'torch_gather_' + str(id(input)), weight=weight)

layer = ctx.network.add_plugin_v2(
inputs=[input_trt, weight_trt], plugin=plugin)
layer = ctx.network.add_gather_v2(weight_trt, input_trt,
trt.GatherMode.DEFAULT)
layer.axis = 0

output._trt = layer.get_output(0)
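
For context: an embedding lookup is a gather along axis 0 of the weight matrix, which is why the dedicated plugin can be replaced by TensorRT's built-in add_gather_v2 with GatherMode.DEFAULT. A minimal PyTorch sketch of the equivalence (illustrative only, not part of this commit):

import torch
import torch.nn.functional as F

weight = torch.rand(10, 4)         # (num_embeddings, embedding_dim)
index = torch.randint(10, (2, 6))  # arbitrary index shape
# An embedding lookup selects rows of `weight` indexed by `index`,
# i.e. a gather along axis 0.
assert torch.equal(F.embedding(index, weight), weight[index])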
8 changes: 2 additions & 6 deletions torch2trt_dynamic/converters/__init__.py
@@ -9,11 +9,13 @@

from . import AdaptiveAvgPool2d # noqa: F401
from . import AdaptiveMaxPool2d # noqa: F401
from . import Embedding # noqa: F401
from . import adaptive_avg_pool1d # noqa: F401
from . import adaptive_avg_pool2d # noqa: F401
from . import adaptive_max_pool1d # noqa: F401
from . import adaptive_max_pool2d # noqa: F401
from . import add # noqa: F401
from . import gather # noqa: F401
from . import grid_sample # noqa: F401
from .activation import (convert_elu, convert_leaky_relu, convert_selu,
convert_softplus, convert_softsign)
@@ -340,8 +342,6 @@
from .cumprod import convert_cumprod
from .cumsum import convert_cumsum
from .deform_conv2d import convert_deform_conv2d
from .Embedding import convert_embedding, convert_embedding_forward
from .gather import convert_gather
from . import GroupNorm # noqa: F401
from .nms import convert_nms
from .roi_align import convert_roi_align, convert_RoiAlign
@@ -360,10 +360,6 @@
__all__ += ['convert_cumsum']
# deform_conv2d
__all__ += ['convert_deform_conv2d']
# Embedding
__all__ += ['convert_embedding', 'convert_embedding_forward']
# gather
__all__ += ['convert_gather']
# nms
__all__ += ['convert_nms']
# roi_align
12 changes: 6 additions & 6 deletions torch2trt_dynamic/converters/gather.py
@@ -1,4 +1,5 @@
from ..plugins import create_torchgather_plugin
import tensorrt as trt

from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_


@@ -13,10 +14,9 @@ def convert_gather(ctx):
inputs_trt = trt_(ctx.network, inputs)
index_trt = trt_(ctx.network, index)

plugin = create_torchgather_plugin(
'torch_gather_' + str(id(inputs)), dim=dim)

layer = ctx.network.add_plugin_v2(
inputs=[inputs_trt, index_trt], plugin=plugin)
layer = ctx.network.add_gather_v2(inputs_trt, index_trt,
trt.GatherMode.ELEMENT)
layer.num_elementwise_dims = 0
layer.axis = dim

output._trt = layer.get_output(0)
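
For context: GatherMode.ELEMENT mirrors torch.gather semantics, where index and output share a shape and each output element is read from the input along the gather axis. A short reference sketch of the behavior being converted (illustrative only, not part of this commit):

import torch

input = torch.rand(3, 4, 5)
index = torch.randint(3, (3, 4, 5))
out = torch.gather(input, 0, index)
# For dim == 0: out[i, j, k] == input[index[i, j, k], j, k]
assert out.shape == index.shape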
11 changes: 4 additions & 7 deletions torch2trt_dynamic/plugins/__init__.py
@@ -7,17 +7,14 @@
from .create_torchbmm_plugin import create_torchbmm_plugin
from .create_torchcum_plugin import create_torchcum_plugin
from .create_torchcummaxmin_plugin import create_torchcummaxmin_plugin
from .create_torchembedding_plugin import create_torchembedding_plugin
from .create_torchgather_plugin import create_torchgather_plugin
from .create_torchunfold_plugin import create_torchunfold_plugin
from .globals import load_plugin_library

__all__ = [
'create_groupnorm_plugin', 'create_torchgather_plugin',
'create_adaptivepool_plugin', 'create_torchcummaxmin_plugin',
'create_torchcum_plugin', 'create_dcn_plugin', 'create_nms_plugin',
'create_roiextractor_plugin', 'create_roipool_plugin',
'create_torchembedding_plugin', 'create_torchbmm_plugin',
'create_groupnorm_plugin', 'create_adaptivepool_plugin',
'create_torchcummaxmin_plugin', 'create_torchcum_plugin',
'create_dcn_plugin', 'create_nms_plugin', 'create_roiextractor_plugin',
'create_roipool_plugin', 'create_torchbmm_plugin',
'create_torchunfold_plugin'
]

22 changes: 0 additions & 22 deletions torch2trt_dynamic/plugins/create_repeatdim_plugin.py

This file was deleted.

25 changes: 0 additions & 25 deletions torch2trt_dynamic/plugins/create_torchembedding_plugin.py

This file was deleted.

16 changes: 0 additions & 16 deletions torch2trt_dynamic/plugins/create_torchgather_plugin.py

This file was deleted.

5 changes: 5 additions & 0 deletions torch2trt_dynamic/torch2trt_dynamic.py
@@ -17,6 +17,7 @@
torch.bool: trt.bool,
torch.int8: trt.int8,
torch.int32: trt.int32,
torch.int64: trt.int32,
torch.float16: trt.float16,
torch.float32: trt.float32,
}
@@ -532,6 +533,10 @@ def _bind_inputs(self, *args, **kwargs):
inputs = self.signature.bind(*args, **kwargs).arguments
inputs = dict(
(name, tensor.contiguous()) for name, tensor in inputs.items())
for name, tensor in inputs.items():
if tensor.dtype == torch.int64:
tensor = tensor.to(torch.int32)
inputs[name] = tensor
self._check_input_shape(inputs)
return inputs
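
Note on the int64 handling above: the dtype map aliases torch.int64 to trt.int32, and _bind_inputs narrows int64 inputs before execution, presumably because the targeted TensorRT version has no native int64 tensor support. The cast is lossless only while values fit in int32; a sketch of the caveat (illustrative only, not part of this commit):

import torch

idx = torch.randint(0, 10, (2, 6), dtype=torch.int64)
idx32 = idx.to(torch.int32)  # safe while values stay below 2**31 - 1
assert torch.equal(idx32.to(torch.int64), idx)
# Values outside the int32 range would wrap, so index inputs must stay in range.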

