[quant][be] Refactor the error checking code for quantize_per_channel op (pytorch#89271)

Summary:
at

Test Plan:
make sure it compiles

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#89271
Approved by: https://github.com/andrewor14
jerryzh168 authored and pytorchmergebot committed Nov 23, 2022
1 parent 71c0e84 commit 128faf2
Showing 1 changed file with 19 additions and 28 deletions.
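
The change below replaces four copies of the scales/zero_points length checks with a single checkPerChannelParamsSize helper and, in passing, makes the failure message report the expected channel count and the actual number of elements. As a rough caller-side sketch only (not part of the diff; it assumes the public at::quantize_per_channel op routes through these checks), a mismatch of the kind the helper rejects looks like this:

#include <ATen/ATen.h>

int main() {
  // A 2x3 float tensor quantized per channel along axis 1, i.e. 3 channels.
  at::Tensor x = at::rand({2, 3});
  at::Tensor scales = at::full({3}, 0.1, at::kDouble);
  at::Tensor zero_points = at::zeros({3}, at::kLong);

  // scales/zero_points lengths match the channel count, so this succeeds.
  at::Tensor q = at::quantize_per_channel(x, scales, zero_points, /*axis=*/1, at::kQInt8);

  // Only 2 scales for 3 channels: the per-channel size check fails, and with
  // this change the TORCH_CHECK message includes the expected vs. actual length.
  at::Tensor bad_scales = at::full({2}, 0.1, at::kDouble);
  try {
    at::quantize_per_channel(x, bad_scales, zero_points, /*axis=*/1, at::kQInt8);
  } catch (const c10::Error& e) {
    // e.what() should contain the size-check message with expected and actual counts.
  }
  return 0;
}
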
aten/src/ATen/native/quantized/AffineQuantizer.cpp (47 changes: 19 additions & 28 deletions)
@@ -97,6 +97,21 @@ void checkSameSize(
       " only works with Tensors with the same shape");
 }
 
+void checkPerChannelParamsSize(
+    const Tensor& rtensor,
+    int64_t axis,
+    const Tensor& scales,
+    const Tensor& zero_points
+) {
+  int64_t channel = rtensor.size(axis);
+  TORCH_CHECK(
+      channel == int64_t(scales.numel()),
+      "length of scales must equal to channel, expected ", channel, " got, ", scales.numel());
+  TORCH_CHECK(
+      channel == int64_t(zero_points.numel()),
+      "length of zero_points must equal to channel expected ", channel, " got, ", zero_points.numel());
+}
+
 } // anonymous namespace
 
 Tensor& quantize_tensor_per_tensor_affine(
@@ -156,13 +171,7 @@ Tensor& quantize_tensor_per_channel_affine(
       "Expected: [0, ",
       rtensor.dim(),
       ")");
-  int64_t channel = rtensor.size(axis);
-  TORCH_CHECK(
-      channel == int64_t(scales.numel()),
-      "length of scales must equal to channel");
-  TORCH_CHECK(
-      channel == int64_t(zero_points.numel()),
-      "length of zero_points must equal to channel");
+  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);
 
   quantize_tensor_per_channel_affine_stub(
       rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
@@ -195,13 +204,7 @@ Tensor& quantize_tensor_per_channel_float_qparams(
       "Expected: [0, ",
       rtensor.dim(),
       ")");
-  int64_t channel = rtensor.size(axis);
-  TORCH_CHECK(
-      channel == int64_t(scales.numel()),
-      "length of scales must equal to channel");
-  TORCH_CHECK(
-      channel == int64_t(zero_points.numel()),
-      "length of zero_points must equal to channel");
+  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);
 
   quantize_tensor_per_channel_float_qparams_stub(
       rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
@@ -260,13 +263,7 @@ Tensor& dequantize_tensor_per_channel_affine(
       " Expected: [0, ",
       qtensor.dim(),
       ")");
-  int64_t channel = qtensor.size(axis);
-  TORCH_CHECK(
-      channel == int64_t(scales.numel()),
-      "length of scales must equal to channel");
-  TORCH_CHECK(
-      channel == int64_t(zero_points.numel()),
-      "length of zero_points must equal to channel");
+  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);
 
   dequantize_tensor_per_channel_affine_stub(
       qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
@@ -297,13 +294,7 @@ Tensor& dequantize_tensor_per_channel_float_qparams(
       " Expected: [0, ",
       qtensor.dim(),
       ")");
-  int64_t channel = qtensor.size(axis);
-  TORCH_CHECK(
-      channel == int64_t(scales.numel()),
-      "length of scales must equal to channel");
-  TORCH_CHECK(
-      channel == int64_t(zero_points.numel()),
-      "length of zero_points must equal to channel");
+  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);
 
   dequantize_tensor_per_channel_float_qparams_stub(
       qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
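The new helper sits in the file's anonymous namespace next to checkSameSize, so it stays private to AffineQuantizer.cpp, and all four per-channel entry points (quantize/dequantize, affine and float_qparams) now share it. Note that the dequantize paths pass rtensor where the removed code measured qtensor.size(axis); presumably this is equivalent because the earlier checkSameSize call requires the two tensors to have the same shape.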
