Commit 5a779c6

reenable back thor test (#3929)
1 parent: ca0765c

12 files changed (+99, -96 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 10 deletions
@@ -10,7 +10,7 @@
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
-from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -429,7 +429,7 @@ def index_nonbool_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
     # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
-    if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
+    if ENABLED_FEATURES.tensorrt_rtx:
         index = node.args[1]
         for ind in index:
             if ind is not None:
@@ -3621,18 +3621,10 @@ def aten_ops_full(
 )


-def nonzero_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    return not is_thor()
-
-
 # currently nonzero is not supported for tensorrt_rtx
 # TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
-# TODO: apbose to remove the capability validator once thor bug resolve in NGC
 @dynamo_tensorrt_converter(
     torch.ops.aten.nonzero.default,
-    capability_validator=nonzero_validator,
     supports_dynamic_shapes=True,
     requires_output_allocator=True,
 )
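Note on the mechanism: `capability_validator` lets a converter opt out per node; as I read the converter registry, a validator returning False means the node is not claimed for TensorRT conversion and falls back to PyTorch execution. The deleted gate was exactly such a platform check. A minimal sketch of the pattern, using only imports that appear in the diff above:

from typing import Optional

from torch.fx.node import Node
from torch_tensorrt._utils import is_thor
from torch_tensorrt.dynamo._settings import CompilationSettings


def nonzero_validator(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    # Returning False rejects aten.nonzero on Thor, leaving the op to run
    # in eager PyTorch instead of the TensorRT engine.
    return not is_thor()

With the Thor bug presumably resolved in NGC (per the deleted TODO), both the validator and its registration hook can go away entirely.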

tests/py/dynamo/conversion/test_arange_aten.py

Lines changed: 0 additions & 5 deletions
@@ -5,15 +5,10 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
-)
 class TestArangeConverter(DispatchTestCase):
     @parameterized.expand(
         [

tests/py/dynamo/conversion/test_cumsum_aten.py

Lines changed: 0 additions & 5 deletions
@@ -5,15 +5,10 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
-)
 class TestCumsumConverter(DispatchTestCase):
     @parameterized.expand(
         [
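The same class-level guard is removed from both test_arange_aten.py (above) and test_cumsum_aten.py, so these converter suites now run on Thor and Tegra. If a local platform gate is ever needed again, a hypothetical sketch follows; the real helpers are `is_thor()` and `is_tegra_platform()` from `torch_tensorrt._utils`, and the compute-capability check below is an assumption of mine, not their actual implementation:

import torch


def _is_thor_like() -> bool:
    # Hypothetical stand-in for torch_tensorrt._utils.is_thor: gate on the
    # CUDA compute capability of device 0. The (10, 1) tuple is an assumed value.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability(0) == (10, 1)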

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 7 additions & 12 deletions
@@ -6,7 +6,6 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import ENABLED_FEATURES, Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase

@@ -114,8 +113,8 @@ def forward(self, input):
         ]
     )
     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_constant_bool_mask(self, _, index, input):
         class TestModule(torch.nn.Module):
@@ -149,8 +148,8 @@ def forward(self, x, index0):
         )

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_two_dim_ITensor_mask(self):
         class TestModule(nn.Module):
@@ -163,10 +162,6 @@ def forward(self, x, index0):
         index0 = torch.tensor([True, False])
         self.run_test(TestModule(), [input, index0], enable_passes=True)

-    @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
-    )
     def test_index_zero_index_three_dim_ITensor(self):
         class TestModule(nn.Module):
             def forward(self, x, index0):
@@ -180,8 +175,8 @@ def forward(self, x, index0):
         self.run_test(TestModule(), [input, index0])

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_index_three_dim_mask_ITensor(self):
         class TestModule(nn.Module):
@@ -252,7 +247,7 @@ def forward(self, input):


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
     "nonzero is not supported for tensorrt_rtx",
 )
 class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
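Why these particular index tests were gated at all: boolean-mask indexing lowers to aten.nonzero plus an integer gather, so the mask tests inherit the backend's nonzero support. A small eager-mode illustration of that equivalence (self-contained, not code from the test file):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([True, False])

direct = x[mask]

# Boolean indexing is equivalent to nonzero + integer indexing, which is why
# the remaining skips are tied to nonzero support in tensorrt_rtx.
idx = mask.nonzero().squeeze(1)
via_nonzero = x[idx]

assert torch.equal(direct, via_nonzero)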

tests/py/dynamo/conversion/test_nonzero_aten.py

Lines changed: 1 addition & 2 deletions
@@ -6,13 +6,12 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
     "nonzero is not supported for tensorrt_rtx",
 )
 class TestNonZeroConverter(DispatchTestCase):
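For context, nonzero has a data-dependent output shape, which is presumably why the converter registration in aten_ops_converters.py keeps requires_output_allocator=True: the engine cannot know the row count ahead of time. A quick illustration:

import torch

a = torch.tensor([0, 1, 0, 2])
b = torch.tensor([1, 1, 1, 1])

# Same input shape, different output shapes: the row count equals the number
# of nonzero elements and is only known at execution time.
print(torch.nonzero(a).shape)  # torch.Size([2, 1])
print(torch.nonzero(b).shape)  # torch.Size([4, 1])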

tests/py/dynamo/conversion/test_sym_size.py

Lines changed: 0 additions & 5 deletions
@@ -4,15 +4,10 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor(),
-    "Skipped on Thor",
-)
 class TestSymSizeConverter(DispatchTestCase):
     @parameterized.expand(
         [
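For readers unfamiliar with the op: aten.sym_size nodes appear when a model reads a dynamic dimension, e.g. x.shape[0] under export with dynamic shapes. A minimal module of the kind such a converter test might exercise (my illustration, not code from the test file):

import torch


class ReadsBatchDim(torch.nn.Module):
    # Reading a dynamic dimension produces sym_size nodes in the exported
    # graph; the converter under test handles those nodes.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(x.shape[0], -1)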

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 10 additions & 11 deletions
@@ -1,6 +1,5 @@
 # type: ignore
 import os
-import tempfile
 import unittest

 import pytest
@@ -22,7 +21,7 @@

 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model():
+def test_custom_model(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -75,15 +74,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace():
+def test_custom_model_with_dynamo_trace(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -137,15 +136,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_dynamic():
+def test_custom_model_with_dynamo_trace_dynamic(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -208,15 +207,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_kwarg_dynamic():
+def test_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
     ir = "dynamo"

     class net(nn.Module):
@@ -298,15 +297,15 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_kwarg_dynamic():
+def test_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
     ir = "dynamo"

     class net(nn.Module):
@@ -388,7 +387,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
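The tempfile.gettempdir() → tmpdir switch is worth noting: every test previously wrote to the same shared "compiled.ep" in the system temp directory, so parallel or repeated runs could race on it, whereas pytest's built-in tmpdir fixture hands each test its own directory. A minimal self-contained sketch of the fixture (the test name is hypothetical):

import os


def test_saves_artifact(tmpdir):
    # pytest injects a per-test temporary directory (a py.path.local), so
    # concurrent tests never collide on the same file path.
    path = os.path.join(tmpdir, "compiled.ep")
    with open(path, "wb") as f:
        f.write(b"\x00")
    assert os.path.exists(path)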
