Commit 5770e00

zewenli98 authored and gs-olive committed
feat: support cumsum dynamo converter (#2403)
1 parent 5b0e5fc · commit 5770e00

File tree

2 files changed: +171 −1 lines

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py (+102 −1)
@@ -1,10 +1,16 @@
 import math
 from typing import Optional

+import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -96,3 +102,98 @@ def expand(
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def chunk(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    chunks: int,
+    dim: int,
+) -> TRTTensor:
+    if chunks <= 0:
+        raise RuntimeError(
+            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
+        )
+
+    shape = input.shape
+    dim = get_positive_dim(dim, len(shape))
+
+    if dim >= len(shape):
+        raise RuntimeError(
+            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
+        )
+
+    dynamic_shape = has_dynamic_shape(input.shape)
+    if dynamic_shape > 0:
+        # Check whether slice target dim is dynamic shape dim
+        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+
+    size_dim = shape[dim]
+    chunk_size = math.ceil(size_dim / chunks)
+    result = []
+    start = 0
+    end = min(start + chunk_size, size_dim)
+    cnt = 0
+
+    while start < end:
+        result.append(
+            slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_{cnt}",
+                input,
+                dim,
+                start,
+                end,
+                1,
+            )
+        )
+        start = end
+        end = min(start + chunk_size, size_dim)
+        cnt += 1
+
+    return result
+
+
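The `chunk` helper above splits a tensor into ceil-sized slices along `dim`, reusing `slice_op` for each piece. As a way to check that slicing arithmetic, here is a minimal pure-PyTorch sketch of the same loop; `chunk_reference` and the test shapes are illustrative, not part of the commit:

import math

import torch


def chunk_reference(x: torch.Tensor, chunks: int, dim: int) -> list:
    # Same arithmetic as the converter: pieces of size ceil(size / chunks)
    # along `dim`, with a possibly smaller final piece.
    size = x.shape[dim]
    step = math.ceil(size / chunks)
    out, start = [], 0
    while start < size:
        end = min(start + step, size)
        out.append(x.narrow(dim, start, end - start))
        start = end
    return out


x = torch.randn(7, 4)
assert all(
    torch.equal(a, b)
    for a, b in zip(chunk_reference(x, 3, dim=0), torch.chunk(x, 3, dim=0))
)

This matches `torch.chunk`, which likewise returns fewer than `chunks` pieces when the dimension does not divide evenly.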
+def cumsum(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    input_shape = input.shape
+    dim = get_positive_dim(dim, len(input_shape))
+    loop = ctx.net.add_loop()
+    axis = np.array(input_shape[dim])
+    trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
+    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
+    iterator = loop.add_iterator(input, dim, reverse=False)
+    data = iterator.get_output(0)
+    new_dims = tuple(data.shape)
+    zeros = np.zeros(new_dims)
+    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
+
+    running_sum = loop.add_recurrence(zero_trttensor)
+    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
+    running_sum_tensor = running_sum.get_output(0)
+
+    current_sum = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_elementwise_add",
+        data,
+        running_sum_tensor,
+    )
+    running_sum.set_input(1, current_sum)
+
+    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
+    set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
+    loop_output.set_input(1, trip_limit)
+    return loop_output.get_output(0)
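The `cumsum` converter is built from TensorRT's loop layers: the trip limit is the size of the summed dimension, an iterator layer feeds one slice of the input per iteration, a recurrence layer carries the running sum starting from zeros, and `trt.LoopOutput.CONCATENATE` stitches the per-iteration sums back together along `dim`. A minimal PyTorch sketch of the same recurrence, useful for pinning down the intended semantics (`cumsum_reference` is an illustrative name, not part of the commit):

import torch


def cumsum_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Mirrors the TRT loop: iterate along `dim`, carry a running sum
    # (the recurrence), and concatenate the partial sums (the loop output).
    running = torch.zeros_like(x.select(dim, 0))  # zero initial value
    partials = []
    for i in range(x.shape[dim]):  # trip limit = size of `dim`
        running = running + x.select(dim, i)  # iterator slice + add
        partials.append(running)
    return torch.stack(partials, dim=dim)  # CONCATENATE along `dim`


x = torch.randn(2, 3, 4)
for d in range(x.dim()):
    assert torch.allclose(cumsum_reference(x, d), torch.cumsum(x, d))
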
New test file (+69 −0)

@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestCumsumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((1,), 0),
+            ((2,), 0),
+            ((3,), -1),
+        ]
+    )
+    def test_cumsum_1D(self, shape, dim):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dim)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 1), 0),
+            ((3, 1), 1),
+            ((2, 3), -1),
+            ((2, 3), -2),
+        ]
+    )
+    def test_cumsum_2D(self, shape, dims):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dims)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((4, 2, 3), 0),
+            ((4, 2, 3), 1),
+            ((1, 2, 3), 2),
+            ((1, 2, 3), -1),
+            ((1, 2, 3), -2),
+        ]
+    )
+    def test_cumsum_3D(self, shape, dims):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dims)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
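The tests compare eager execution against TensorRT through the `DispatchTestCase` harness. For context, a hedged sketch of how the converter would be exercised end to end through the dynamo frontend; this snippet is not part of the commit and assumes a CUDA GPU with TensorRT available:

import torch
import torch_tensorrt


class CumsumModule(torch.nn.Module):
    def forward(self, x):
        return torch.cumsum(x, dim=1)  # traces to aten.cumsum.default


model = CumsumModule().eval().cuda()
x = torch.randn(2, 3).cuda()

# Compile through the dynamo path, which now dispatches cumsum to the
# loop-based converter added in this commit.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
assert torch.allclose(trt_model(x), model(x), atol=1e-3, rtol=1e-3)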
