Skip to content

Commit

Permalink
[quant] Make conv2d_prepack and linear_prepack pure (pytorch#35073)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#35073

We want to do constant propagation for quantize_per_tensor/quantize_per_channel
which will produce results that's consumed by these ops, and since we need to
make sure the output of the node has no writer before constant prop through the node,
the consumer needs to be pure as well.

Test Plan:
see next PR

Imported from OSS

Differential Revision: D20655310

fbshipit-source-id: 3e33662224c21b889c8121b823f8ce0b7da75eed
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Mar 27, 2020
1 parent e1773f2 commit 04a3345
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
24 changes: 14 additions & 10 deletions aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,23 @@ static auto registry =
c10::RegisterOperators()
.op("quantized::conv_prepack", // conv_prepack is deprecated, please use
// conv2d_prepack for 2D conv.
c10::RegisterOperators::options().kernel<QConvPackWeightInt8<2>>(
DispatchKey::QuantizedCPUTensorId))
.op("quantized::conv2d_prepack", // We use conv2d_prepack to be
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QConvPackWeightInt8<2>>(DispatchKey::QuantizedCPUTensorId))
.op("quantized::conv2d_prepack", // We use conv2d_prepack to be
// consistent with conv3d_prepack
c10::RegisterOperators::options().kernel<QConvPackWeightInt8<2>>(
DispatchKey::QuantizedCPUTensorId))
.op("_quantized::conv2d_prepack", // We use conv2d_prepack to be
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QConvPackWeightInt8<2>>(DispatchKey::QuantizedCPUTensorId))
.op("_quantized::conv2d_prepack", // We use conv2d_prepack to be
// consistent with conv3d_prepack
c10::RegisterOperators::options().kernel<QConvPackWeightInt8<2>>(
DispatchKey::QuantizedCPUTensorId))
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QConvPackWeightInt8<2>>(DispatchKey::QuantizedCPUTensorId))
.op("quantized::conv3d_prepack",
c10::RegisterOperators::options().kernel<QConvPackWeightInt8<3>>(
DispatchKey::QuantizedCPUTensorId));
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QConvPackWeightInt8<3>>(DispatchKey::QuantizedCPUTensorId));

} // namespace
} // namespace native
Expand Down
20 changes: 12 additions & 8 deletions aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,21 @@ class QLinearPackWeightFp16 final : public c10::OperatorKernel {
static auto registry =
c10::RegisterOperators()
.op("quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack",
c10::RegisterOperators::options().kernel<QLinearPackWeightInt8>(
DispatchKey::QuantizedCPUTensorId))
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QLinearPackWeightInt8>(DispatchKey::QuantizedCPUTensorId))
.op("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack",
c10::RegisterOperators::options().kernel<QLinearPackWeightFp16>(
DispatchKey::CPUTensorId))
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QLinearPackWeightFp16>(DispatchKey::CPUTensorId))
.op("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack",
c10::RegisterOperators::options().kernel<QLinearPackWeightInt8>(
DispatchKey::QuantizedCPUTensorId))
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QLinearPackWeightInt8>(DispatchKey::QuantizedCPUTensorId))
.op("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack",
c10::RegisterOperators::options().kernel<QLinearPackWeightFp16>(
DispatchKey::CPUTensorId));
c10::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
.kernel<QLinearPackWeightFp16>(DispatchKey::CPUTensorId));

} // namespace
} // namespace native
Expand Down

0 comments on commit 04a3345

Please sign in to comment.