Commit bf46d54

Merge pull request #2107 from andi4191/anurag.dixit/aten_fake_quant
feat: Added a variant for aten::fake_quant_per_tensor
2 parents: 3c49608 + 6e51901

File tree: 2 files changed, +59 −11 lines changed

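In brief: the converter's Quantize→Dequantize (Q/DQ) emission is factored into a shared add_qdq helper, and a second pattern is registered for the aten::fake_quantize_per_tensor_affine.tensor_qparams overload, whose scale and zero_point arrive as Tensors rather than scalars. For orientation, here is a minimal reference sketch of what this op computes, assuming the documented PyTorch semantics (round to nearest, shift by the zero point, clamp, then invert); fake_quant_ref is an illustrative name, not code from this PR:

#include <algorithm>
#include <cmath>

// Reference fake-quantize: quantize to the integer grid, then dequantize,
// so the output stays FP32 but carries the INT8 rounding/clamping error.
float fake_quant_ref(float x, float scale, int zero_point, int quant_min, int quant_max) {
  int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;  // quantize
  q = std::min(quant_max, std::max(quant_min, q));                   // clamp to [quant_min, quant_max]
  return static_cast<float>(q - zero_point) * scale;                 // dequantize
}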

core/conversion/converters/impl/quantization.cpp (+25 −11)
@@ -11,6 +11,22 @@ namespace {
 
 #if NV_TENSORRT_MAJOR > 7
 // clang-format off
+
+bool add_qdq(ConversionCtx *ctx, const torch::jit::Node* n, nvinfer1::ITensor* input, nvinfer1::ITensor* scale, std::string& opName) {
+  nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
+  TORCHTRT_CHECK(quantize_layer, "Unable to create QuantizeLayer from node: " << *n);
+  quantize_layer->setAxis(0);
+
+  nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
+  TORCHTRT_CHECK(dequantize_layer, "Unable to create DequantizeLayer from node: " << *n);
+  dequantize_layer->setAxis(0);
+
+  auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
+  LOG_DEBUG("[" << opName << "]" << " Output tensor shape: " << qdq_out->getDimensions());
+
+  return true;
+}
+
 auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
     .pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -20,18 +36,16 @@ auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
                 auto scale = args[1].unwrapToScalar().to<float>();
                 auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
                 // Add and configure a QuantizeLayer.
-                nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
-                quantize_layer->setAxis(0);
-
-                // Add and configure DequantizeLayer following a QuantizeLayer
-                nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
-                dequantize_layer->setAxis(0);
-
-                auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
-                LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());
-
-                return true;
+                std::string opName("aten::fake_quantize_per_tensor_affine");
+                return add_qdq(ctx, n, input, scaleTensor, opName);
               }})
+    .pattern({"aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> (Tensor)",
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                auto input = args[0].ITensorOrFreeze(ctx);
+                auto scale = args[1].ITensorOrFreeze(ctx);
+                std::string opName("aten::fake_quantize_per_tensor_affine.tensor_qparams");
+                return add_qdq(ctx, n, input, scale, opName);
+              }})
     .pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                 // This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
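Design note: both per-tensor patterns now funnel into add_qdq, which expects the scale as an nvinfer1::ITensor. The scalar overload materializes its float scale as a one-element constant via tensor_to_const, while the new .tensor_qparams overload freezes the incoming scale Tensor with ITensorOrFreeze. With a single-element scale, setAxis(0) amounts to per-tensor quantization; zero_point, quant_min, and quant_max are not forwarded, since TensorRT's explicit Q/DQ layers assume symmetric INT8 (zero point 0, range [-128, 127]).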

tests/core/conversion/converters/test_quantization.cpp (+34 −0)
@@ -30,6 +30,40 @@ TEST(Converters, ATenFakeQuantizePerTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenFakeQuantizePerTensorWithParamsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %22 : int = prim::Constant[value=-128]()
+        %14 : int = prim::Constant[value=4]()
+        %9 : None = prim::Constant()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %6 : int = prim::Constant[value=6]()
+        %7 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %5 : float = prim::Constant[value=3.5]()
+        %13 : int = prim::Constant[value=1]()
+        %23 : int = prim::Constant[value=127]()
+        %4 : int[] = prim::ListConstruct(%3)
+        %11 : Tensor = aten::full(%4, %5, %6, %9, %35, %9)
+        %12 : int[] = prim::ListConstruct(%3)
+        %19 : Tensor = aten::full(%12, %13, %7, %9, %35, %9)
+        %quant_input.1 : Tensor = aten::fake_quantize_per_tensor_affine(%x.1, %11, %19, %22, %23)
+        return (%quant_input.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenFakeQuantizePerChannelConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
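The IR above builds the qparams as one-element tensors via aten::full (scale = 3.5 with dtype 6 = float, zero_point = 1 with dtype 3 = int) and uses the range [-128, 127], so running this graph through conversion exercises the new tensor_qparams path end to end. A hypothetical spot-check of one value under the reference semantics sketched earlier (not part of the test suite):

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  // Same qparams as the IR graph: scale 3.5, zero_point 1, range [-128, 127].
  float x = 10.0f, scale = 3.5f;
  int zp = 1, qmin = -128, qmax = 127;
  int q = std::min(qmax, std::max(qmin, static_cast<int>(std::nearbyint(x / scale)) + zp));  // round(10/3.5)=3, +1 -> q = 4
  assert(static_cast<float>(q - zp) * scale == 10.5f);  // (4 - 1) * 3.5: output keeps the rounding error
  return 0;
}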
