@@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
589
589
// Validate identical graphs after pooling constants and canonicalizing
590
590
ASSERT_TRUE ((tg->toString () == sg->toString ()));
591
591
}
592
+
593
+ TEST (LoweringPasses, RemoveCollectionCastTuple) {
594
+ // Ensure the lowering pass transforms the first graph into the second
595
+ std::string source_graph = R"IR(
596
+ graph(%x.1 : Tensor):
597
+ %3 : int = prim::Constant[value=1]()
598
+ %2 : int = prim::Constant[value=2]()
599
+ %a.1 : Tensor = aten::mul(%x.1, %2)
600
+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
601
+ %c.1 : Tensor = aten::relu(%b.1)
602
+ %d.1 : Tensor = aten::sqrt(%c.1)
603
+ %8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
604
+ return (%8))IR" ;
605
+
606
+ std::string target_graph = R"IR(
607
+ graph(%x.1 : Tensor):
608
+ %3 : int = prim::Constant[value=1]()
609
+ %2 : int = prim::Constant[value=2]()
610
+ %a.1 : Tensor = aten::mul(%x.1, %2)
611
+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
612
+ %c.1 : Tensor = aten::relu(%b.1)
613
+ %d.1 : Tensor = aten::sqrt(%c.1)
614
+ return (%c.1, %d.1, %b.1))IR" ;
615
+
616
+ // Ensure the lowering pass transforms the first graph into the second
617
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
618
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
619
+ auto sg = std::make_shared<torch::jit::Graph>();
620
+ torch::jit::parseIR (source_graph, sg.get ());
621
+
622
+ torch_tensorrt::core::lowering::passes::RemoveCollectionCast (sg);
623
+ torch::jit::ConstantPooling (sg);
624
+ sg = torch::jit::Canonicalize (sg, false );
625
+
626
+ auto tg = std::make_shared<torch::jit::Graph>();
627
+ torch::jit::parseIR (target_graph, tg.get ());
628
+
629
+ torch::jit::ConstantPooling (tg);
630
+ tg = torch::jit::Canonicalize (tg, false );
631
+
632
+ // Validate identical graphs after pooling constants and canonicalizing
633
+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
634
+ }
635
+
636
+ TEST (LoweringPasses, RemoveCollectionCastList) {
637
+ // Ensure the lowering pass transforms the first graph into the second
638
+ std::string source_graph = R"IR(
639
+ graph(%x.1 : Tensor):
640
+ %3 : int = prim::Constant[value=1]()
641
+ %2 : int = prim::Constant[value=2]()
642
+ %a.1 : Tensor = aten::mul(%x.1, %2)
643
+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
644
+ %c.1 : Tensor = aten::relu(%b.1)
645
+ %d.1 : Tensor = aten::sqrt(%c.1)
646
+ %8 : (Tensor, Tensor, Tensor) = prim::ListConstruct(%b.1, %c.1, %d.1)
647
+ return (%8))IR" ;
648
+
649
+ std::string target_graph = R"IR(
650
+ graph(%x.1 : Tensor):
651
+ %3 : int = prim::Constant[value=1]()
652
+ %2 : int = prim::Constant[value=2]()
653
+ %a.1 : Tensor = aten::mul(%x.1, %2)
654
+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
655
+ %c.1 : Tensor = aten::relu(%b.1)
656
+ %d.1 : Tensor = aten::sqrt(%c.1)
657
+ return (%b.1, %c.1, %d.1))IR" ;
658
+
659
+ // Ensure the lowering pass transforms the first graph into the second
660
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
661
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
662
+ auto sg = std::make_shared<torch::jit::Graph>();
663
+ torch::jit::parseIR (source_graph, sg.get ());
664
+
665
+ torch_tensorrt::core::lowering::passes::RemoveCollectionCast (sg);
666
+ torch::jit::ConstantPooling (sg);
667
+ sg = torch::jit::Canonicalize (sg, false );
668
+
669
+ auto tg = std::make_shared<torch::jit::Graph>();
670
+ torch::jit::parseIR (target_graph, tg.get ());
671
+
672
+ torch::jit::ConstantPooling (tg);
673
+ tg = torch::jit::Canonicalize (tg, false );
674
+
675
+ // Validate identical graphs after pooling constants and canonicalizing
676
+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
677
+ }
0 commit comments