
Commit 0c1481c

Add normalization of cat operator argument into valid input (#3890)
1 parent b447f08 commit 0c1481c
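
The change matters because torch.ops.aten.cat.default accepts its dim either positionally or as a keyword, so a Dynamo-traced graph may carry dim in node.args or in node.kwargs; the previous converter read only args, so a keyword dim silently fell back to the default of 0. A minimal eager-mode sketch of the two call patterns (shapes here are illustrative):

    import torch

    x = torch.randn(1, 2, 3)
    y = torch.randn(1, 1, 3)

    # Pattern 1: dim passed positionally -> appears in node.args when traced
    a = torch.ops.aten.cat.default((x, y), 1)
    # Pattern 2: dim passed as a keyword -> appears in node.kwargs when traced
    b = torch.ops.aten.cat.default((x, y), dim=1)
    assert torch.equal(a, b)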

File tree

3 files changed: +75 -6 lines changed


examples/distributed_inference/data_parallel_stable_diffusion.py

Lines changed: 0 additions & 2 deletions
@@ -53,7 +53,5 @@
 
 # Assume there are 2 processes (2 devices)
 with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
     result = pipe(prompt).images[0]
-    print("after ")
     result.save(f"result_{distributed_state.process_index}.png")

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 40 additions & 4 deletions
@@ -2,7 +2,7 @@
 
 import logging
 import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -218,21 +218,57 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles these valid patterns:
+    1. args = ((t1, t2, ...), dim)
+    2. args = ((t1, t2, ...),) with an optional dim in kwargs
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If the single positional arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default,
+    supports_dynamic_shapes=True,
+)
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    inputs, dim = parse_cat_args(args, kwargs)
     return impl.cat.cat(
         ctx,
         target,
         SourceIR.ATEN,
         name,
-        input=args[0],
-        dim=args_bounds_check(args, 1, 0),
+        input=inputs,
+        dim=dim,
     )
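To see the normalization in isolation, here is a self-contained sketch of parse_cat_args with args_bounds_check stubbed in (an illustrative reimplementation, not the torch_tensorrt source; the strings stand in for tensors):

    from typing import Any, Dict, List, Sequence, Tuple

    def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
        # Positional argument if present, otherwise the supplied default
        return args[i] if len(args) > i else replacement

    def parse_cat_args(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[List[Any], int]:
        if len(args) > 1 and isinstance(args[0], (list, tuple)):
            input_tensors = list(args[0])
            dim = args_bounds_check(args, 1, 0)
        else:
            if len(args) == 1 and isinstance(args[0], (list, tuple)):
                input_tensors = list(args[0])
            else:
                input_tensors = list(args)
            dim = kwargs.get("dim", 0)
        return input_tensors, dim

    # Pattern 1: dim positional
    assert parse_cat_args((("t1", "t2"), 1), {}) == (["t1", "t2"], 1)
    # Pattern 2: dim as a keyword (the case the old converter defaulted to 0)
    assert parse_cat_args((("t1", "t2"),), {"dim": -2}) == (["t1", "t2"], -2)
    # No dim anywhere: falls back to 0
    assert parse_cat_args((("t1", "t2"),), {}) == (["t1", "t2"], 0)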
tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 35 additions & 0 deletions
@@ -25,6 +25,41 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 1),
+            ("neg", -2),
+        ]
+    )
+    def test_cat_dim_in_kwargs(self, _, dim):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.ops.aten.cat.default((x, y, z), dim=dim)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_scalar_inputs(self, _, dim):
+        # Ensure cat handles a scalar-like value materialized as a tensor
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                # y simulates a scalar broadcast to x's shape
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        x = torch.randn(1, 2, 3, device="cuda")
+        y = torch.ones_like(x) * 5.0  # scalar value 5.0 expanded to x's shape
+        inputs = [x, y]
+        self.run_test(Cat(), inputs)
+
     @parameterized.expand(
         [
             ("pos", 1),
