@@ -931,3 +931,38 @@ TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {
931
931
932
932
ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
933
933
}
934
+
935
+ TEST (Evaluators, IsAtenSliceEvaluateCorrectly) {
936
+ const auto graph = R"IR(
937
+ graph():
938
+ %1 : int[] = prim::Constant[value= 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]()
939
+ %2 : int = prim::Constant[value = 0]()
940
+ %3 : int = prim::Constant[value = 7]()
941
+ %4 : int = prim::Constant[value = 2]()
942
+ %5 : int[] = aten::slice(%1, %2, %3, %4)
943
+ return (%5))IR" ;
944
+
945
+ auto g = std::make_shared<torch::jit::Graph>();
946
+ torch::jit::parseIR (graph, g.get ());
947
+
948
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {});
949
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {});
950
+
951
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
952
+ }
953
+
954
+ TEST (Evaluators, IsAtenListEvaluateCorrectly) {
955
+ const auto graph = R"IR(
956
+ graph():
957
+ %1 : int[] = prim::Constant[value= 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]()
958
+ %2 : int[] = aten::list(%1)
959
+ return (%2))IR" ;
960
+
961
+ auto g = std::make_shared<torch::jit::Graph>();
962
+ torch::jit::parseIR (graph, g.get ());
963
+
964
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {});
965
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {});
966
+
967
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
968
+ }
0 commit comments