Skip to content

Commit 9a100b6

Browse files
authored
feat: Fixed conv1d converter when weights are Tensor (#2542)
Signed-off-by: Anurag Dixit <a.dixit91@gmail.com>
1 parent afd5abe commit 9a100b6

File tree

2 files changed

+130
-4
lines changed

2 files changed

+130
-4
lines changed

core/conversion/converters/impl/conv_deconv.cpp

+28-4
Original file line numberDiff line numberDiff line change
@@ -131,26 +131,43 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
131131

132132
// Make a new Dims with only the spatial dimensions.
133133
nvinfer1::Dims filter_dim;
134+
nvinfer1::Dims original_dim = in->getDimensions();
134135
int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
135136
TORCHTRT_CHECK(
136137
nbSpatialDims = kernel_dims.nbDims - 2,
137138
"Number of input spatial dimensions should match the kernel spatial dimensions");
138139
filter_dim.nbDims = nbSpatialDims;
139140
filter_dim.d[0] = kernel_dims.d[2];
140141
filter_dim.d[1] = kernel_dims.d[3];
142+
int32_t num_output_maps = kernel_dims.d[0];
143+
bool expand_dims = nbSpatialDims == 1;
144+
if (expand_dims) {
145+
// In case of Conv1D -> map it to 2D version
146+
// TensorRT expects nbSpatialDims = 2 or 3
147+
filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false);
148+
// Reshape input dimensions
149+
in = addPadding(ctx, n, in, 4);
150+
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
151+
kernel = addPadding(ctx, n, kernel, 4);
152+
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
153+
if (transposed) {
154+
num_output_maps = kernel_dims.d[1];
155+
}
156+
}
141157

142158
// Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
143159
auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
144160

145161
nvinfer1::ILayer* layer = nullptr;
162+
nvinfer1::ITensor* out = nullptr;
146163
if (transposed) {
147164
// Fix padding based on output_padding provided
148165
nvinfer1::Dims begPadding = padding;
149166
bool hasOutputPadding = false;
150167
add_output_padding(padding, out_padding, hasOutputPadding);
151168

152169
nvinfer1::IDeconvolutionLayer* deconvLayer = ctx->net->addDeconvolutionNd(
153-
*in, kernel_dims.d[0], filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
170+
*in, num_output_maps, filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
154171
deconvLayer->setStrideNd(stride);
155172
deconvLayer->setDilationNd(dilation);
156173
deconvLayer->setNbGroups(groups);
@@ -161,15 +178,21 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
161178
deconvLayer->setInput(1, *kernel);
162179
TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
163180
layer = deconvLayer;
181+
out = deconvLayer->getOutput(0);
164182
if (hasOutputPadding) {
165183
LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
166184
nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
167185
auto dims = in->getDimensions();
168186
layer = add_bias_layer(ctx, tensorPtr, dims, out_padding, bias);
187+
out = layer->getOutput(0);
188+
}
189+
if (expand_dims) {
190+
// Un-expand the expanded dimension
191+
out = addUnpadding(ctx, n, out, original_dim.nbDims);
169192
}
170193
} else {
171194
nvinfer1::IConvolutionLayer* convLayer =
172-
ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
195+
ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
173196
convLayer->setStrideNd(stride);
174197
convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
175198
convLayer->setPaddingNd(padding);
@@ -180,10 +203,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
180203
// Set conv kernel weights
181204
convLayer->setInput(1, *kernel);
182205
layer = convLayer;
206+
out = layer->getOutput(0);
183207
}
184208

185-
ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
186-
LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
209+
ctx->AssociateValueAndTensor(n->outputs()[0], out);
210+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
187211
return true;
188212
}
189213

tests/core/conversion/converters/test_conv_deconv.cpp

+102
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,57 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
116116
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
117117
}
118118

119+
TEST(Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) {
120+
const auto graph = R"IR(
121+
graph(%0 : Tensor,
122+
%1 : Float(4, 5, 3, strides=[15, 3, 1])):
123+
%2 : int = prim::Constant[value=-128]()
124+
%3 : float = prim::Constant[value=3.5]()
125+
%4 : int = prim::Constant[value=0]()
126+
%5 : int = prim::Constant[value=127]()
127+
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
128+
%6 : int = prim::Constant[value=6]()
129+
%7 : int = prim::Constant[value=5]()
130+
%8 : Device = prim::Constant[value="cuda:0"]()
131+
%9 : None = prim::Constant()
132+
%10 : int[] = prim::ListConstruct(%7)
133+
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
134+
%12 : int[] = prim::ListConstruct(%7)
135+
%13 : int = prim::Constant[value=1]()
136+
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
137+
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
138+
%15 : None = prim::Constant()
139+
%16 : int = prim::Constant[value=1]()
140+
%17 : int = prim::Constant[value=0]()
141+
%18 : int = prim::Constant[value=1]()
142+
%19 : int = prim::Constant[value=0]()
143+
%20 : bool = prim::Constant[value=0]()
144+
%21 : int[] = prim::ListConstruct(%16)
145+
%22 : int[] = prim::ListConstruct(%17)
146+
%23 : int[] = prim::ListConstruct(%18)
147+
%24 : int[] = prim::ListConstruct(%19)
148+
%25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %20, %24, %16, %20, %20, %20, %20)
149+
return (%25))IR";
150+
151+
auto g = std::make_shared<torch::jit::Graph>();
152+
torch::jit::parseIR(graph, g.get());
153+
154+
auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
155+
auto w = at::randint(1, 2, {4, 5, 3}, {at::kCUDA});
156+
157+
auto jit_in = at::clone(in);
158+
auto jit_w = at::clone(w);
159+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
160+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
161+
162+
auto trt_in = at::clone(in);
163+
auto trt_w = at::clone(w);
164+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
165+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);
166+
167+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
168+
}
169+
119170
TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
120171
const auto graph = R"IR(
121172
graph(%0 : Tensor,
@@ -609,6 +660,57 @@ TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
609660
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
610661
}
611662

663+
TEST(Converters, ATenConv1dTransposeWithWeightTensorsConvertsCorrectly) {
664+
const auto graph = R"IR(
665+
graph(%0 : Tensor,
666+
%1 : Float(4, 5, 3, strides=[15, 3, 1])):
667+
%2 : int = prim::Constant[value=-128]()
668+
%3 : float = prim::Constant[value=3.5]()
669+
%4 : int = prim::Constant[value=0]()
670+
%5 : int = prim::Constant[value=127]()
671+
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
672+
%6 : int = prim::Constant[value=6]()
673+
%7 : int = prim::Constant[value=4]()
674+
%8 : Device = prim::Constant[value="cuda:0"]()
675+
%9 : None = prim::Constant()
676+
%10 : int[] = prim::ListConstruct(%7)
677+
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
678+
%12 : int[] = prim::ListConstruct(%7)
679+
%13 : int = prim::Constant[value=1]()
680+
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
681+
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
682+
%15 : None = prim::Constant()
683+
%16 : int = prim::Constant[value=1]()
684+
%17 : int = prim::Constant[value=0]()
685+
%18 : int = prim::Constant[value=1]()
686+
%19 : int = prim::Constant[value=0]()
687+
%20 : bool = prim::Constant[value=0]()
688+
%21 : int[] = prim::ListConstruct(%16)
689+
%22 : int[] = prim::ListConstruct(%17)
690+
%23 : int[] = prim::ListConstruct(%18)
691+
%24 : int[] = prim::ListConstruct(%19)
692+
%25 : bool = prim::Constant[value=1]()
693+
%26 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %25, %24, %18, %20, %20, %20, %20)
694+
return (%26))IR";
695+
auto g = std::make_shared<torch::jit::Graph>();
696+
torch::jit::parseIR(graph, g.get());
697+
698+
auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
699+
auto w = at::randint(1, 2, {5, 4, 3}, {at::kCUDA});
700+
701+
auto jit_in = at::clone(in);
702+
auto jit_w = at::clone(w);
703+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
704+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
705+
706+
auto trt_in = at::clone(in);
707+
auto trt_w = at::clone(w);
708+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
709+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);
710+
711+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
712+
}
713+
612714
TEST(Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
613715
const auto graph = R"IR(
614716
graph(%0 : Tensor,

0 commit comments

Comments (0)