@@ -116,6 +116,57 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
116
116
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
117
117
}
118
118
119
// Verifies that aten::_convolution converts correctly for a 1D convolution
// whose input and weights are fake-quantized (QAT-style):
//   * the activation (%0) is quantized per-tensor via
//     aten::fake_quantize_per_tensor_affine with scale 3.5, zero-point 0,
//     quant range [-128, 127];
//   * the weights (%1, shape (4, 5, 3)) are quantized per-channel via
//     aten::fake_quantize_per_channel_affine with 5 scales (all 3.5) and
//     5 zero-points (all 1), axis %13 (= 1).
//     NOTE(review): axis 1 with 5 channels matches dim 1 of the weight, but
//     conv weights are usually quantized along the output-channel axis (0) —
//     confirm axis 1 is intentional here.
// The TorchScript interpreter result is then compared against the TensorRT
// engine built at INT8 precision (nvinfer1::DataType::kINT8).
TEST(Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 5, 3, strides=[15, 3, 1])):
        %2 : int = prim::Constant[value=-128]()
        %3 : float = prim::Constant[value=3.5]()
        %4 : int = prim::Constant[value=0]()
        %5 : int = prim::Constant[value=127]()
        %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
        %6 : int = prim::Constant[value=6]()
        %7 : int = prim::Constant[value=5]()
        %8 : Device = prim::Constant[value="cuda:0"]()
        %9 : None = prim::Constant()
        %10 : int[] = prim::ListConstruct(%7)
        %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
        %12 : int[] = prim::ListConstruct(%7)
        %13 : int = prim::Constant[value=1]()
        %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
        %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
        %15 : None = prim::Constant()
        %16 : int = prim::Constant[value=1]()
        %17 : int = prim::Constant[value=0]()
        %18 : int = prim::Constant[value=1]()
        %19 : int = prim::Constant[value=0]()
        %20 : bool = prim::Constant[value=0]()
        %21 : int[] = prim::ListConstruct(%16)
        %22 : int[] = prim::ListConstruct(%17)
        %23 : int[] = prim::ListConstruct(%18)
        %24 : int[] = prim::ListConstruct(%19)
        %25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %20, %24, %16, %20, %20, %20, %20)
        return (%25))IR";

  // Parse the IR string into a fresh TorchScript graph.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Small integer-valued CUDA tensors: input in [1, 10), weights all 1
  // (randint upper bound 2 is exclusive), so quantization error stays tiny
  // and the 2e-6 tolerance below is meaningful.
  auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
  auto w = at::randint(1, 2, {4, 5, 3}, {at::kCUDA});

  // Clone the tensors so the JIT reference run and the TensorRT run see
  // byte-identical, independent inputs.
  auto jit_in = at::clone(in);
  auto jit_w = at::clone(w);
  // The weight is bound as a static (frozen) graph parameter, not a runtime input.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_w = at::clone(w);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
  // Build and run the TensorRT engine with INT8 precision enabled.
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
119
170
TEST (Converters, ATenConvolutionNoBiasConvertsCorrectly) {
120
171
const auto graph = R"IR(
121
172
graph(%0 : Tensor,
@@ -609,6 +660,57 @@ TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
609
660
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
610
661
}
611
662
663
// Verifies that aten::_convolution with transposed=true (%25 = 1) converts
// correctly for a 1D transposed convolution with fake-quantized (QAT-style)
// input and weights:
//   * the activation (%0) is quantized per-tensor (scale 3.5, zero-point 0,
//     range [-128, 127]);
//   * the weights (%1) are quantized per-channel with 4 scales/zero-points
//     on axis %13 (= 1) — for a transposed conv the weight layout is
//     (in_channels, out_channels, kW), so axis 1 is the output-channel axis.
// NOTE(review): the IR type annotation for %1 reads Float(4, 5, 3,
// strides=[15, 3, 1]) but the tensor actually bound below is {5, 4, 3}
// (whose contiguous strides would be [12, 3, 1]); the annotation appears
// stale — confirm against the non-quantized transpose tests.
// The TorchScript result is compared against the TensorRT engine built at
// INT8 precision.
TEST(Converters, ATenConv1dTransposeWithWeightTensorsConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 5, 3, strides=[15, 3, 1])):
        %2 : int = prim::Constant[value=-128]()
        %3 : float = prim::Constant[value=3.5]()
        %4 : int = prim::Constant[value=0]()
        %5 : int = prim::Constant[value=127]()
        %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
        %6 : int = prim::Constant[value=6]()
        %7 : int = prim::Constant[value=4]()
        %8 : Device = prim::Constant[value="cuda:0"]()
        %9 : None = prim::Constant()
        %10 : int[] = prim::ListConstruct(%7)
        %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
        %12 : int[] = prim::ListConstruct(%7)
        %13 : int = prim::Constant[value=1]()
        %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
        %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
        %15 : None = prim::Constant()
        %16 : int = prim::Constant[value=1]()
        %17 : int = prim::Constant[value=0]()
        %18 : int = prim::Constant[value=1]()
        %19 : int = prim::Constant[value=0]()
        %20 : bool = prim::Constant[value=0]()
        %21 : int[] = prim::ListConstruct(%16)
        %22 : int[] = prim::ListConstruct(%17)
        %23 : int[] = prim::ListConstruct(%18)
        %24 : int[] = prim::ListConstruct(%19)
        %25 : bool = prim::Constant[value=1]()
        %26 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %25, %24, %18, %20, %20, %20, %20)
        return (%26))IR";
  // Parse the IR string into a fresh TorchScript graph.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Small integer-valued CUDA tensors; weights are all 1 (randint upper
  // bound 2 is exclusive) to keep quantization error within the tight
  // tolerance below. Weight layout: (in_channels=5, out_channels=4, kW=3).
  auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
  auto w = at::randint(1, 2, {5, 4, 3}, {at::kCUDA});

  // Clone so the JIT reference run and the TensorRT run use identical,
  // independent inputs.
  auto jit_in = at::clone(in);
  auto jit_w = at::clone(w);
  // The weight is bound as a static (frozen) graph parameter, not a runtime input.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_w = at::clone(w);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
  // Build and run the TensorRT engine with INT8 precision enabled.
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
612
714
TEST (Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
613
715
const auto graph = R"IR(
614
716
graph(%0 : Tensor,
0 commit comments