Skip to content

Commit a9fde87

Browse files
committed
add tests for aten::slice and aten::list evaluator
1 parent 5f325ec commit a9fde87

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -931,3 +931,38 @@ TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {
931931

932932
ASSERT_TRUE(jit_results[0] == trt_results[0]);
933933
}
934+
935+
TEST(Evaluators, IsAtenSliceEvaluateCorrectly) {
936+
const auto graph = R"IR(
937+
graph():
938+
%1 : int[] = prim::Constant[value= 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]()
939+
%2 : int = prim::Constant[value = 0]()
940+
%3 : int = prim::Constant[value = 7]()
941+
%4 : int = prim::Constant[value = 2]()
942+
%5 : int[] = aten::slice(%1, %2, %3, %4)
943+
return (%5))IR";
944+
945+
auto g = std::make_shared<torch::jit::Graph>();
946+
torch::jit::parseIR(graph, g.get());
947+
948+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
949+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
950+
951+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
952+
}
953+
954+
TEST(Evaluators, IsAtenListEvaluateCorrectly) {
955+
const auto graph = R"IR(
956+
graph():
957+
%1 : int[] = prim::Constant[value= 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]()
958+
%2 : int[] = aten::list(%1)
959+
return (%2))IR";
960+
961+
auto g = std::make_shared<torch::jit::Graph>();
962+
torch::jit::parseIR(graph, g.get());
963+
964+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
965+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
966+
967+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
968+
}

0 commit comments

Comments
 (0)