66from parameterized import parameterized
77from torch .testing ._internal .common_utils import run_tests
88from torch_tensorrt import ENABLED_FEATURES , Input
9- from torch_tensorrt ._utils import is_tegra_platform , is_thor
109
1110from .harness import DispatchTestCase
1211
@@ -114,8 +113,8 @@ def forward(self, input):
114113 ]
115114 )
116115 @unittest .skipIf (
117- is_thor () or ENABLED_FEATURES .tensorrt_rtx ,
118- "Skipped on Thor or tensorrt_rtx due to nonzero not supported" ,
116+ ENABLED_FEATURES .tensorrt_rtx ,
117+ "Skipped on tensorrt_rtx due to nonzero not supported" ,
119118 )
120119 def test_index_constant_bool_mask (self , _ , index , input ):
121120 class TestModule (torch .nn .Module ):
@@ -149,8 +148,8 @@ def forward(self, x, index0):
149148 )
150149
151150 @unittest .skipIf (
152- is_thor () or ENABLED_FEATURES .tensorrt_rtx ,
153- "Skipped on Thor or tensorrt_rtx due to nonzero not supported" ,
151+ ENABLED_FEATURES .tensorrt_rtx ,
152+ "Skipped on tensorrt_rtx due to nonzero not supported" ,
154153 )
155154 def test_index_zero_two_dim_ITensor_mask (self ):
156155 class TestModule (nn .Module ):
@@ -163,10 +162,6 @@ def forward(self, x, index0):
163162 index0 = torch .tensor ([True , False ])
164163 self .run_test (TestModule (), [input , index0 ], enable_passes = True )
165164
166- @unittest .skipIf (
167- is_thor (),
168- "Skipped on Thor due to nonzero not supported" ,
169- )
170165 def test_index_zero_index_three_dim_ITensor (self ):
171166 class TestModule (nn .Module ):
172167 def forward (self , x , index0 ):
@@ -180,8 +175,8 @@ def forward(self, x, index0):
180175 self .run_test (TestModule (), [input , index0 ])
181176
182177 @unittest .skipIf (
183- is_thor () or ENABLED_FEATURES .tensorrt_rtx ,
184- "Skipped on Thor or tensorrt_rtx due to nonzero not supported" ,
178+ ENABLED_FEATURES .tensorrt_rtx ,
179+ "Skipped on tensorrt_rtx due to nonzero not supported" ,
185180 )
186181 def test_index_zero_index_three_dim_mask_ITensor (self ):
187182 class TestModule (nn .Module ):
@@ -252,7 +247,7 @@ def forward(self, input):
252247
253248
254249@unittest .skipIf (
255- torch_tensorrt .ENABLED_FEATURES .tensorrt_rtx or is_thor () or is_tegra_platform () ,
250+ torch_tensorrt .ENABLED_FEATURES .tensorrt_rtx ,
256251 "nonzero is not supported for tensorrt_rtx" ,
257252)
258253class TestIndexDynamicInputNonDynamicIndexConverter (DispatchTestCase ):
0 commit comments