Skip to content

Commit 6cf31cf

Browse files
committed
Fixed the CI
1 parent 049b810 commit 6cf31cf

File tree

2 files changed

+29
-36
lines changed

2 files changed

+29
-36
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,15 @@ def index(
258258
else:
259259
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
260260

261-
if isinstance(dim_tensor_shape_mult_d1, TRTTensor):
262-
mult_d1 = convert_binary_elementwise(
263-
ctx,
264-
target,
265-
source_ir,
266-
name + f"_shape_{i}",
267-
trt.ElementWiseOperation.PROD,
268-
mult_d1,
269-
dim_tensor_shape_mult_d1,
270-
)
261+
mult_d1 = convert_binary_elementwise(
262+
ctx,
263+
target,
264+
source_ir,
265+
name + f"_shape_{i}",
266+
trt.ElementWiseOperation.PROD,
267+
mult_d1,
268+
dim_tensor_shape_mult_d1,
269+
)
271270

272271
concat_tensor_layer = ctx.net.add_concatenation(
273272
[

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,16 @@ def forward(self, x, y, z, a, b):
310310
{0: seq_len3},
311311
),
312312
)
313-
with torchtrt.dynamo.Debugger(
314-
log_level="debug",
315-
capture_fx_graph_after=["remove_num_users_is_0_nodes"],
316-
logging_dir="/home/profile/logging/moe",
317-
engine_builder_monitor=False,
318-
):
319-
trt_mod = torchtrt.dynamo.compile(
320-
ep,
321-
inputs,
322-
enabled_precisions={torch.float16},
323-
min_block_size=1,
324-
use_explicit_typing=False,
325-
use_fp32_acc=False,
326-
disable_tf32=True,
327-
)
313+
314+
trt_mod = torchtrt.dynamo.compile(
315+
ep,
316+
inputs,
317+
enabled_precisions={torch.float16},
318+
min_block_size=1,
319+
use_explicit_typing=False,
320+
use_fp32_acc=False,
321+
disable_tf32=True,
322+
)
328323
result = trt_mod(*inputs)
329324
assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
330325

@@ -350,17 +345,16 @@ def forward(self, source_tensor, indices_tensor, value_tensor):
350345
(source_tensor, indices_tensor, value_tensor),
351346
dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}),
352347
)
353-
with torchtrt.dynamo.Debugger(log_level="debug"):
354-
trt_engine = torchtrt.dynamo.compile(
355-
ep,
356-
inputs=(source_tensor, indices_tensor, value_tensor),
357-
enabled_precisions={torch.float32},
358-
min_block_size=1,
359-
use_explicit_typing=False,
360-
use_fp32_acc=False,
361-
disable_tf32=True,
362-
use_python_runtime=True,
363-
)
348+
trt_engine = torchtrt.dynamo.compile(
349+
ep,
350+
inputs=(source_tensor, indices_tensor, value_tensor),
351+
enabled_precisions={torch.float32},
352+
min_block_size=1,
353+
use_explicit_typing=False,
354+
use_fp32_acc=False,
355+
disable_tf32=True,
356+
use_python_runtime=True,
357+
)
364358
result = trt_engine(source_tensor, indices_tensor, value_tensor)
365359

366360
torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)

0 commit comments

Comments
 (0)