From 04a33453358ec5aef96d481cfc759b3ebaaf6f10 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 27 Mar 2020 14:12:11 -0700 Subject: [PATCH] [quant] Make conv2d_prepack and linear_prepack pure (#35073) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35073 We want to do constant propagation for quantize_per_tensor/quantize_per_channel which will produce results that's consumed by these ops, and since we need to make sure the output of the node has no writer before constant prop through the node, the consumer needs to be pure as well. Test Plan: see next PR Imported from OSS Differential Revision: D20655310 fbshipit-source-id: 3e33662224c21b889c8121b823f8ce0b7da75eed --- .../native/quantized/cpu/qconv_prepack.cpp | 24 +++++++++++-------- .../native/quantized/cpu/qlinear_prepack.cpp | 20 +++++++++------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 6bc2e5a8be3ff4..0601f2d913d7cf 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -305,19 +305,23 @@ static auto registry = c10::RegisterOperators() .op("quantized::conv_prepack", // conv_prepack is deprecated, please use // conv2d_prepack for 2D conv. - c10::RegisterOperators::options().kernel>( - DispatchKey::QuantizedCPUTensorId)) - .op("quantized::conv2d_prepack", // We use conv2d_prepack to be + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel>(DispatchKey::QuantizedCPUTensorId)) + .op("quantized::conv2d_prepack", // We use conv2d_prepack to be // consistent with conv3d_prepack - c10::RegisterOperators::options().kernel>( - DispatchKey::QuantizedCPUTensorId)) - .op("_quantized::conv2d_prepack", // We use conv2d_prepack to be + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel>(DispatchKey::QuantizedCPUTensorId)) + .op("_quantized::conv2d_prepack", // We use conv2d_prepack to be // consistent with conv3d_prepack - c10::RegisterOperators::options().kernel>( - DispatchKey::QuantizedCPUTensorId)) + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv3d_prepack", - c10::RegisterOperators::options().kernel>( - DispatchKey::QuantizedCPUTensorId)); + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel>(DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 2a6d93a3b50ac8..e7fc82edc7599e 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -298,17 +298,21 @@ class QLinearPackWeightFp16 final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators() .op("quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack", - c10::RegisterOperators::options().kernel( - DispatchKey::QuantizedCPUTensorId)) + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel(DispatchKey::QuantizedCPUTensorId)) .op("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack", - c10::RegisterOperators::options().kernel( - DispatchKey::CPUTensorId)) + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel(DispatchKey::CPUTensorId)) .op("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack", - c10::RegisterOperators::options().kernel( - DispatchKey::QuantizedCPUTensorId)) + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel(DispatchKey::QuantizedCPUTensorId)) .op("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack", - c10::RegisterOperators::options().kernel( - DispatchKey::CPUTensorId)); + c10::RegisterOperators::options() + .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION) + .kernel(DispatchKey::CPUTensorId)); } // namespace } // namespace native