[tensorexpr][nnc] Support quantization (pytorch#66676)
Summary: Pull Request resolved: pytorch#66676

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31676329

Pulled By: IvanKobzarev

fbshipit-source-id: 288b41ff4ed603dfaacb465f296997f14bb23c22
IvanKobzarev authored and facebook-github-bot committed Nov 1, 2021
1 parent 97f29bd commit 7fbcf79
Showing 22 changed files with 1,157 additions and 11 deletions.
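For context, a minimal eager-mode sketch of the quantize/dequantize round trip the new lowerings are validated against in the tests below (it reuses the same ATen calls and constants as those tests and is not part of the diff itself):

#include <ATen/ATen.h>

int main() {
  // Quantize a float tensor with a fixed scale and zero point, then
  // dequantize it back; NNC's new lowerings reproduce this round trip.
  at::Tensor x = at::rand({2, 2});
  at::Tensor q = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/13, at::kQUInt8);
  at::Tensor y = at::dequantize(q);
  // y matches x up to quantization error (one step == scale).
  return 0;
}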
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -393,6 +393,7 @@ namespace c10 {
_(aten, hardswish_) \
_(aten, hardsigmoid_) \
_(aten, hardtanh_) \
_(aten, quantize_per_tensor) \
_(aten, dequantize) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
5 changes: 5 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -309,4 +309,9 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
}

} // namespace

Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
return qadd<false>(qa, qb, scale, zero_point);
}

}} // namespace at::native
8 changes: 8 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qadd.h
@@ -0,0 +1,8 @@
#include <ATen/ATen.h>

namespace at {
namespace native {
TORCH_API Tensor
quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point);
}
} // namespace at
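The new header exposes the previously file-local qadd wrapper as at::native::quantized_add. A minimal usage sketch, assuming both inputs are quantized with the same dtype; the helper name add_via_quantized is illustrative and not part of the diff:

#include <ATen/ATen.h>
#include <ATen/native/quantized/cpu/qadd.h>

// Illustrative helper: add two float tensors through the quantized path.
at::Tensor add_via_quantized(const at::Tensor& a, const at::Tensor& b) {
  auto qa = at::quantize_per_tensor(a, /*scale=*/0.1, /*zero_point=*/13, at::kQUInt8);
  auto qb = at::quantize_per_tensor(b, /*scale=*/0.1, /*zero_point=*/13, at::kQUInt8);
  // quantized_add requantizes the sum to the requested output scale/zero point.
  auto qsum = at::native::quantized_add(qa, qb, /*scale=*/0.1, /*zero_point=*/13);
  return at::dequantize(qsum);
}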
27 changes: 26 additions & 1 deletion binaries/aot_model_compiler.cc
@@ -5,13 +5,20 @@
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/script.h>


C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");
@@ -166,7 +173,25 @@ int main(int argc, char** argv) {
m.eval();
auto frozen_m = torch::jit::freeze_module(m.clone());
auto graph = frozen_m.get_method(FLAGS_method_name).graph();
auto input_shapes = parseInputShapes();
std::vector<c10::optional<at::Tensor>> example_inputs;
example_inputs.reserve(input_shapes.size());
for (const auto& input_shape : input_shapes) {
example_inputs.emplace_back(at::rand(input_shape));
}

torch::jit::RemoveTensorMutation(graph);
torch::jit::EliminateDeadCode(graph->block());
graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
torch::jit::OptimizeFrozenGraph(graph, true);
torch::jit::PropagateShapesOnGraph(graph);
torch::jit::PeepholeOptimize(graph, false);
torch::jit::ConstantPropagation(graph);
torch::jit::PropagateShapesOnGraph(graph);
torch::jit::PeepholeOptimize(graph, false);
torch::jit::ConstantPropagation(graph);

auto compile_spec = createCompileSpec();
auto any_dict_ty =
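The passes added above specialize the frozen graph to concrete input shapes before NNC kernel generation. A condensed sketch of just the shape-specialization step, assuming a single example input with a hypothetical shape (flag parsing and the other cleanup passes are elided):

// Attach concrete tensor types to the graph inputs so the downstream passes
// (shape propagation, peephole, constant propagation) see static shapes.
std::vector<c10::optional<at::Tensor>> example_inputs;
example_inputs.emplace_back(at::rand({1, 3, 224, 224})); // assumed input shape
torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
torch::jit::PropagateShapesOnGraph(graph);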
1 change: 1 addition & 0 deletions test/cpp/tensorexpr/CMakeLists.txt
@@ -15,6 +15,7 @@ set(TENSOREXPR_TEST_SRCS
${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp
${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp
${TENSOREXPR_TEST_ROOT}/test_ops.cpp
${TENSOREXPR_TEST_ROOT}/test_quantization.cpp
${TENSOREXPR_TEST_ROOT}/test_reductions.cpp
${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp
${TENSOREXPR_TEST_ROOT}/test_simplify.cpp
231 changes: 231 additions & 0 deletions test/cpp/tensorexpr/test_quantization.cpp
@@ -0,0 +1,231 @@
#include <gtest/gtest.h>

#include <ATen/native/quantized/cpu/conv_packed_params.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

class Quantization : public ::testing::Test {
public:
// NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
void SetUp() {
getTEMustUseLLVMOnCPU() = false;
}
};

TEST_F(Quantization, QuantDequantInt8) {
const auto graph_string = R"IR(
graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=12]()
%3 : int = prim::Constant[value=13]()
%4 : float = prim::Constant[value=0.1]()
%q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
%6 : Float(2, 2) = aten::dequantize(%q.1)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8);
auto y_expected = at::dequantize(q);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantDequantUInt8) {
const auto graph_string = R"IR(
graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=13]()
%3 : int = prim::Constant[value=122]()
%4 : float = prim::Constant[value=0.1]()
%q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
%6 : Float(2, 2) = aten::dequantize(%q.1)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
auto y_expected = at::dequantize(q);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

at::Tensor quantized_add(
at::Tensor x1,
at::Tensor x2,
double scale,
int64_t zero) {
const auto qadd_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("quantized::add", "")
.typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
return qadd_op.call(x1, x2, scale, zero);
}

TEST_F(Quantization, QuantAddDequantInt8) {
const auto graph_string = R"IR(
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=12]()
%qz1 : int = prim::Constant[value=13]()
%qs1 : float = prim::Constant[value=0.1]()
%qz2 : int = prim::Constant[value=13]()
%qs2 : float = prim::Constant[value=0.1]()
%qza : int = prim::Constant[value=13]()
%qsa : float = prim::Constant[value=0.1]()
%q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
%q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
%qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
%6 : Float(2, 2) = aten::dequantize(%qa)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
auto qa = quantized_add(q1, q2, 0.1f, 13);
auto y_expected = at::dequantize(qa);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x1, x2};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "x1:\n" << x1 << std::endl;
std::cout << "q1:\n" << q1 << std::endl;
std::cout << "x2:\n" << x2 << std::endl;
std::cout << "q2:\n" << q2 << std::endl;
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantAddDequantUInt8) {
const auto graph_string = R"IR(
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=13]()
%qz1 : int = prim::Constant[value=13]()
%qs1 : float = prim::Constant[value=0.1]()
%qz2 : int = prim::Constant[value=13]()
%qs2 : float = prim::Constant[value=0.1]()
%qza : int = prim::Constant[value=13]()
%qsa : float = prim::Constant[value=0.1]()
%q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
%q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
%qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
%6 : Float(2, 2) = aten::dequantize(%qa)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
auto qa = quantized_add(q1, q2, 0.1f, 13);
auto y_expected = at::dequantize(qa);

TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x1, x2};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "x1:\n" << x1 << std::endl;
std::cout << "q1:\n" << q1 << std::endl;
std::cout << "x2:\n" << x2 << std::endl;
std::cout << "q2:\n" << q2 << std::endl;
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) {
const auto graph_string = R"IR(
graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
%2 : int = prim::Constant[value=13]()
%4 : NoneType = prim::Constant()
%3 : int[] = prim::Constant[value=[4, 4]]()
%qz : int = prim::Constant[value=13]()
%qs : float = prim::Constant[value=0.1]()
%q : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
%qu : QUInt8(1, 1, 4, 4) = aten::upsample_nearest2d(%q, %3, %4)
%6 : Float(1, 1, 4, 4) = aten::dequantize(%qu)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
auto qu = at::upsample_nearest2d(q, {4, 4});
auto y_expected = at::dequantize(qu);

TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "x:\n" << x << std::endl;
std::cout << "q:\n" << q << std::endl;
std::cout << "qu:\n" << qu << std::endl;
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

} // namespace jit
} // namespace torch
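Assuming the standard CMake test targets, these cases would typically be exercised through the test_tensorexpr gtest binary, e.g. filtered with --gtest_filter=Quantization.*.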
1 change: 1 addition & 0 deletions test/test_jit_fuser_te.py
@@ -1992,6 +1992,7 @@ def test(x, y, z):
'ceil',
'clamp',
'clamp.scalar',
'contiguous',
'cos',
'cosh',
'div.no_rounding_mode',
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -323,6 +323,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/tensorexpr/operators/misc.cpp",
"torch/csrc/jit/tensorexpr/operators/norm.cpp",
"torch/csrc/jit/tensorexpr/operators/pointwise.cpp",
"torch/csrc/jit/tensorexpr/operators/quantization.cpp",
"torch/csrc/jit/tensorexpr/operators/reduction.cpp",
"torch/csrc/jit/tensorexpr/operators/softmax.cpp",
"torch/csrc/jit/tensorexpr/reduction.cpp",
(Diffs for the remaining changed files are not shown.)
