@@ -116,6 +116,57 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
116116 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
117117}
118118
// Checks that an aten::_convolution with single-element stride/padding/dilation
// lists, whose input goes through aten::fake_quantize_per_tensor_affine and whose
// weights go through aten::fake_quantize_per_channel_affine, yields the same
// result from RunGraph and from RunGraphEngine when the latter is asked for
// nvinfer1::DataType::kINT8.
//
// The weight tensor is bound to graph input %1 as a static parameter (via
// get_static_params); only the activation tensor is passed as a runtime input.
// The per-channel scale/zero-point tensors are built inside the graph with
// aten::full and have 5 elements, matching dimension 1 (the axis given by %13)
// of the 4x5x3 weight.
119+ TEST (Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) {
120+ const auto graph = R"IR(
121+ graph(%0 : Tensor,
122+ %1 : Float(4, 5, 3, strides=[15, 3, 1])):
123+ %2 : int = prim::Constant[value=-128]()
124+ %3 : float = prim::Constant[value=3.5]()
125+ %4 : int = prim::Constant[value=0]()
126+ %5 : int = prim::Constant[value=127]()
127+ %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
128+ %6 : int = prim::Constant[value=6]()
129+ %7 : int = prim::Constant[value=5]()
130+ %8 : Device = prim::Constant[value="cuda:0"]()
131+ %9 : None = prim::Constant()
132+ %10 : int[] = prim::ListConstruct(%7)
133+ %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
134+ %12 : int[] = prim::ListConstruct(%7)
135+ %13 : int = prim::Constant[value=1]()
136+ %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
137+ %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
138+ %15 : None = prim::Constant()
139+ %16 : int = prim::Constant[value=1]()
140+ %17 : int = prim::Constant[value=0]()
141+ %18 : int = prim::Constant[value=1]()
142+ %19 : int = prim::Constant[value=0]()
143+ %20 : bool = prim::Constant[value=0]()
144+ %21 : int[] = prim::ListConstruct(%16)
145+ %22 : int[] = prim::ListConstruct(%17)
146+ %23 : int[] = prim::ListConstruct(%18)
147+ %24 : int[] = prim::ListConstruct(%19)
148+ %25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %20, %24, %16, %20, %20, %20, %20)
149+ return (%25))IR" ;
150+
// Parse the IR text above into a freshly allocated torch::jit::Graph.
151+ auto g = std::make_shared<torch::jit::Graph>();
152+ torch::jit::parseIR (graph, g.get ());
153+
// CUDA tensors for the activation and the weight, both shaped {4, 5, 3}.
// NOTE(review): at::randint(1, 2, ...) presumably yields all ones because the
// upper bound is exclusive -- confirm against the at::randint documentation.
154+ auto in = at::randint (1 , 10 , {4 , 5 , 3 }, {at::kCUDA });
155+ auto w = at::randint (1 , 2 , {4 , 5 , 3 }, {at::kCUDA });
156+
// Reference run: clones of the input and weight are fed to RunGraph, with the
// weight clone bound as the graph's static parameter.
157+ auto jit_in = at::clone (in);
158+ auto jit_w = at::clone (w);
159+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_w});
160+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
161+
// Engine run: a second pair of clones goes through RunGraphEngine with kINT8
// requested; `params` is rebuilt around the new weight clone.
162+ auto trt_in = at::clone (in);
163+ auto trt_w = at::clone (w);
164+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_w});
165+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in}, nvinfer1::DataType::kINT8 );
166+
// The first output of each run must agree within the 2e-6 tolerance.
167+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
168+ }
169+
119170TEST (Converters, ATenConvolutionNoBiasConvertsCorrectly) {
120171 const auto graph = R"IR(
121172 graph(%0 : Tensor,
@@ -609,6 +660,57 @@ TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
609660 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
610661}
611662
// Checks that a transposed aten::_convolution (transposed flag %25 = true) with
// single-element stride/padding/dilation/output-padding lists, whose input goes
// through aten::fake_quantize_per_tensor_affine and whose weights go through
// aten::fake_quantize_per_channel_affine, yields the same result from RunGraph
// and from RunGraphEngine when the latter is asked for nvinfer1::DataType::kINT8.
//
// The weight tensor is bound to graph input %1 as a static parameter (via
// get_static_params); only the activation tensor is passed as a runtime input.
// The weight is shaped {5, 4, 3}: its leading dimension matches the 5 channels
// of the {4, 5, 3} input, and the per-channel scale/zero-point tensors built
// with aten::full have 4 elements, matching dimension 1 (the axis given by %13).
// The type annotation on %1 is kept in agreement with that tensor
// (contiguous strides for 5x4x3 are [12, 3, 1]).
663+ TEST (Converters, ATenConv1dTransposeWithWeightTensorsConvertsCorrectly) {
664+ const auto graph = R"IR(
665+ graph(%0 : Tensor,
666+ %1 : Float(5, 4, 3, strides=[12, 3, 1])):
667+ %2 : int = prim::Constant[value=-128]()
668+ %3 : float = prim::Constant[value=3.5]()
669+ %4 : int = prim::Constant[value=0]()
670+ %5 : int = prim::Constant[value=127]()
671+ %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
672+ %6 : int = prim::Constant[value=6]()
673+ %7 : int = prim::Constant[value=4]()
674+ %8 : Device = prim::Constant[value="cuda:0"]()
675+ %9 : None = prim::Constant()
676+ %10 : int[] = prim::ListConstruct(%7)
677+ %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
678+ %12 : int[] = prim::ListConstruct(%7)
679+ %13 : int = prim::Constant[value=1]()
680+ %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
681+ %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
682+ %15 : None = prim::Constant()
683+ %16 : int = prim::Constant[value=1]()
684+ %17 : int = prim::Constant[value=0]()
685+ %18 : int = prim::Constant[value=1]()
686+ %19 : int = prim::Constant[value=0]()
687+ %20 : bool = prim::Constant[value=0]()
688+ %21 : int[] = prim::ListConstruct(%16)
689+ %22 : int[] = prim::ListConstruct(%17)
690+ %23 : int[] = prim::ListConstruct(%18)
691+ %24 : int[] = prim::ListConstruct(%19)
692+ %25 : bool = prim::Constant[value=1]()
693+ %26 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %25, %24, %18, %20, %20, %20, %20)
694+ return (%26))IR" ;
// Parse the IR text above into a freshly allocated torch::jit::Graph.
695+ auto g = std::make_shared<torch::jit::Graph>();
696+ torch::jit::parseIR (graph, g.get ());
697+
// CUDA tensors: activation shaped {4, 5, 3}, weight shaped {5, 4, 3}.
// NOTE(review): at::randint(1, 2, ...) presumably yields all ones because the
// upper bound is exclusive -- confirm against the at::randint documentation.
698+ auto in = at::randint (1 , 10 , {4 , 5 , 3 }, {at::kCUDA });
699+ auto w = at::randint (1 , 2 , {5 , 4 , 3 }, {at::kCUDA });
700+
// Reference run: clones of the input and weight are fed to RunGraph, with the
// weight clone bound as the graph's static parameter.
701+ auto jit_in = at::clone (in);
702+ auto jit_w = at::clone (w);
703+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_w});
704+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
705+
// Engine run: a second pair of clones goes through RunGraphEngine with kINT8
// requested; `params` is rebuilt around the new weight clone.
706+ auto trt_in = at::clone (in);
707+ auto trt_w = at::clone (w);
708+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_w});
709+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in}, nvinfer1::DataType::kINT8 );
710+
// The first output of each run must agree within the 2e-6 tolerance.
711+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
712+ }
713+
612714TEST (Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
613715 const auto graph = R"IR(
614716 graph(%0 : Tensor,
0 commit comments