Adding quantized clamp kernel (pytorch#30541)
Summary:
Pull Request resolved: pytorch#30541

ghstack-source-id: 95450749

Adding quantized clamp kernel

Test Plan:
Added test.

buck test mode/dev //caffe2/test:quantized -- 'test_qclamp \(test_quantized\.TestQuantizedOps\)' --print-passing-details

Differential Revision: D18739628

fbshipit-source-id: 38a029ab96c5b0689bb15c67dc4f274883e74975
dskhudia authored and facebook-github-bot committed Dec 12, 2019
1 parent 1d5af95 commit a2463cb
Showing 7 changed files with 228 additions and 0 deletions.
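
For orientation, the new operator can be exercised from Python roughly as in the added test. This is a minimal sketch, not part of the commit; the scale, zero_point, and dtype values are illustrative:

import torch

# Quantize a float tensor, then clamp it entirely in the quantized domain.
x = torch.randn(2, 3)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# Op registered by this commit; both min and max must be provided.
qy = torch.ops.quantized.clamp(qx, -0.5, 0.5)
print(qy.dequantize())
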
78 changes: 78 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -265,6 +265,24 @@ struct Vec256<c10::qint8> {
#endif
  }

  Vec256<c10::qint8> minimum(Vec256<c10::qint8> b) const {
#ifdef __AVX2__
    return _mm256_min_epi8(vals, b.vals);
#else
    // Pray the compiler can autovectorize this
    int8_t int_vals[size()];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
    int8_t b_vals[size()];
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(&b_vals), b.vals);
    int8_t result_vals[size()];
    for (size_t i = 0; i < size(); ++i) {
      result_vals[i] = std::min<int8_t>(int_vals[i], b_vals[i]);
    }
    return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
  }

  Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) const {
    return maximum(zero_point);
  }
@@ -435,6 +453,24 @@ struct Vec256<c10::quint8> {
#endif
  }

  Vec256<c10::quint8> minimum(Vec256<c10::quint8> b) const {
#ifdef __AVX2__
    return _mm256_min_epu8(vals, b.vals);
#else
    // Pray the compiler can autovectorize this
    uint8_t int_vals[size()];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
    uint8_t b_vals[size()];
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(&b_vals), b.vals);
    uint8_t result_vals[size()];
    for (size_t i = 0; i < size(); ++i) {
      result_vals[i] = std::min<uint8_t>(int_vals[i], b_vals[i]);
    }
    return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
  }

  Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) const {
    return maximum(zero_point);
  }
@@ -562,6 +598,24 @@ struct Vec256<c10::qint32> {
#endif
  }

  Vec256<c10::qint32> minimum(Vec256<c10::qint32> b) const {
#ifdef __AVX2__
    return _mm256_min_epi32(vals, b.vals);
#else
    // Pray the compiler can autovectorize this
    int32_t int_vals[size()];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
    int32_t b_vals[size()];
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(&b_vals), b.vals);
    int32_t result_vals[size()];
    for (size_t i = 0; i < size(); ++i) {
      result_vals[i] = std::min<int32_t>(int_vals[i], b_vals[i]);
    }
    return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
  }

  Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) const {
    return maximum(zero_point);
  }
@@ -722,6 +776,14 @@ struct Vec256<c10::qint8> : public Vec256QuantizedConverter<
    return retval;
  }

  Vec256<c10::qint8> minimum(Vec256<c10::qint8> b) const {
    Vec256<c10::qint8> retval;
    for (size_t i = 0; i < size(); ++i) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) const {
    return maximum(zero_point);
  }
@@ -792,6 +854,14 @@ struct Vec256<c10::quint8> : public Vec256QuantizedConverter<
    return retval;
  }

  Vec256<c10::quint8> minimum(Vec256<c10::quint8> b) const {
    Vec256<c10::quint8> retval;
    for (size_t i = 0; i < size(); ++i) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) const {
    return maximum(zero_point);
  }
@@ -863,6 +933,14 @@ struct Vec256<c10::qint32> : public Vec256QuantizedConverter<
    return retval;
  }

  Vec256<c10::qint32> minimum(Vec256<c10::qint32> b) const {
    Vec256<c10::qint32> retval;
    for (size_t i = 0; i < size(); ++i) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) const {
    return maximum(zero_point);
  }
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -651,6 +651,10 @@
  use_c10_dispatcher: full
  supports_named_tensor: True
  variants: function, method
  dispatch:
    CPU: clamp
    CUDA: clamp
    QuantizedCPU: quantized_clamp

- func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
  supports_named_tensor: True
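
With the QuantizedCPU entry added above, an ordinary torch.clamp call on a quantized CPU tensor should now dispatch to quantized_clamp. A sketch under that assumption (quantization parameters are illustrative):

import torch

qx = torch.quantize_per_tensor(torch.randn(4), scale=0.05, zero_point=0,
                               dtype=torch.qint8)
# Should route to quantized_clamp via the QuantizedCPU entry above;
# both bounds are required by the quantized kernel.
qy = torch.clamp(qx, min=-0.2, max=0.2)
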
37 changes: 37 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -192,6 +192,42 @@ void qrelu6_kernel(const Tensor& qx, Tensor& qy) {
  });
}

void qclamp_kernel(
    const Tensor& qx,
    Scalar min_scalar,
    Scalar max_scalar,
    Tensor& qy) {
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
    qy = at::_empty_affine_quantized(
        qx.sizes(),
        at::device(kCPU).dtype(SCALAR_TYPE),
        qx.q_scale(),
        qx.q_zero_point(),
        qx.suggest_memory_format());
    using Vec = Vec256<scalar_t>;
    auto iter = TensorIterator::unary_op(qy, qx);
    auto min = min_scalar.to<float>();
    auto max = max_scalar.to<float>();
    scalar_t min_q =
        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), min);
    scalar_t max_q =
        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), max);
    auto min_vec = Vec(min_q);
    auto max_vec = Vec(max_q);
    cpu_kernel_vec(
        iter,
        [&](scalar_t value) -> scalar_t {
          underlying_t min_clamped =
              std::max<underlying_t>(value.val_, min_q.val_);
          return scalar_t(std::min<underlying_t>(min_clamped, max_q.val_));
        },
        [&](Vec val) -> Vec {
          auto min_clamped = val.maximum(min_vec);
          return min_clamped.minimum(max_vec);
        });
  });
}

// Note: out is assumed to be the same size as self and other.
// Note: Addition is only supported when self, other, out are of the same dtype.
template <bool ReLUFused = false>
@@ -814,6 +850,7 @@ void qtopk_kernel(Tensor& values,

REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
REGISTER_DISPATCH(qrelu6_stub, &qrelu6_kernel);
REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel);
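
The kernel above quantizes the min/max scalars with the input's scale and zero_point and then clamps directly on the stored integer values, so the output reuses the input's quantization parameters. A NumPy sketch of that arithmetic for the quint8 case (illustrative qparams, plain round-to-nearest; not the actual ATen code path):

import numpy as np

scale, zero_point = 0.1, 128

def quantize(x):
    # float -> quint8 representation using the input tensor's qparams
    return np.clip(np.round(np.asarray(x) / scale) + zero_point, 0, 255).astype(np.uint8)

q_x = quantize([-5.0, -0.3, 0.0, 0.3, 5.0])
q_min, q_max = quantize(-0.5), quantize(0.5)   # 123 and 133 with these qparams

# Integer-domain clamp: maximum against q_min, then minimum against q_max,
# mirroring val.maximum(min_vec) followed by .minimum(max_vec) above.
q_y = np.minimum(np.maximum(q_x, q_min), q_max)
y = (q_y.astype(np.float32) - zero_point) * scale   # dequantized: [-0.5, -0.3, 0.0, 0.3, 0.5]
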
60 changes: 60 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qclamp.cpp
@@ -0,0 +1,60 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/quantized/Quantizer.h>

#include <algorithm>

namespace at {
namespace native {

DEFINE_DISPATCH(qclamp_stub);

namespace {
Tensor quantized_clamp_impl(
    const Tensor& qx,
    optional<Scalar> min,
    optional<Scalar> max) {
  Tensor qy;
  if (min && max) {
    qclamp_stub(qx.device().type(), qx, *min, *max, qy);
  } else {
    TORCH_CHECK(
        false, "Both min and max should be specified for quantized clamp!");
  }
  return qy;
}
} // namespace

// at::native functions for the native_functions.yaml
Tensor quantized_clamp(
    const Tensor& qx,
    optional<Scalar> min,
    optional<Scalar> max) {
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "clamp", [&]() {
    qy = quantized_clamp_impl(qx, min, max);
  });
  return qy;
}

// Keep the registry in the anonymous namespace.
namespace {
class QClamp final : public c10::OperatorKernel {
 public:
  Tensor operator()(Tensor qx, optional<Scalar> min, optional<Scalar> max) {
    return quantized_clamp(qx, min, max);
  }
};

static auto registry = c10::RegisterOperators().op(
    "quantized::clamp(Tensor qx, Scalar? min, Scalar? max) -> Tensor qy",
    c10::RegisterOperators::options().kernel<QClamp>(
        TensorTypeId::QuantizedCPUTensorId));
} // namespace

} // namespace native
} // namespace at
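
The registered schema accepts optional bounds, but the TORCH_CHECK above rejects a missing one. A sketch of the expected behavior, assuming None is accepted for the optional Scalar arguments (qparams illustrative):

import torch

qx = torch.quantize_per_tensor(torch.randn(3), scale=0.1, zero_point=0,
                               dtype=torch.qint8)

qy = torch.ops.quantized.clamp(qx, -1.0, 1.0)   # OK: both bounds given

try:
    torch.ops.quantized.clamp(qx, None, 1.0)    # min omitted
except RuntimeError as e:
    # Expected to surface the "Both min and max should be specified" check.
    print(e)
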
6 changes: 6 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -6,6 +6,11 @@ namespace at {
namespace native {

using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qclamp_fn = void (*)(
    const at::Tensor& /*qx*/,
    Scalar min,
    Scalar max,
    at::Tensor& /*qy*/);
using qadd_fn =
    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
using qmaxpool_2d_fn = void (*)(
@@ -79,6 +84,7 @@ using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool
// using qavg_pool2d_fn
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
DECLARE_DISPATCH(qrelu_fn, qrelu6_stub);
DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
DECLARE_DISPATCH(qadd_fn, qadd_stub);
DECLARE_DISPATCH(qadd_fn, qadd_relu_stub);
DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
27 changes: 27 additions & 0 deletions test/test_quantized.py
@@ -144,6 +144,33 @@ def test_qrelu6(self, X):
            self.assertEqual(qY, qY_hat,
                             message="{} relu failed".format(name))

    """Tests the correctness of the quantized::clamp op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8),
                       elements=st.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           min_val=st.floats(-1e6, 1e6, allow_nan=False),
           max_val=st.floats(-1e6, 1e6, allow_nan=False))
    def test_qclamp(self, X, min_val, max_val):
        X, (scale, zero_point, torch_type) = X

        assume(min_val <= max_val)
        Y = X.copy()
        Y[Y < min_val] = min_val
        Y[Y > max_val] = max_val
        qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                       zero_point=zero_point, dtype=torch_type)
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        ops_under_test = {
            'ops.quantized': torch.ops.quantized.clamp,
        }

        for name, op in ops_under_test.items():
            qY_hat = op(qX, min_val, max_val)
            self.assertEqual(qY, qY_hat, message="{} qclamp failed".format(name))

    """Tests the correctness of the scalar addition."""
    @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                       elements=st.floats(-1e6, 1e6, allow_nan=False),
16 changes: 16 additions & 0 deletions torch/nn/quantized/functional.py
@@ -293,6 +293,22 @@ def relu(input, inplace=False):
    else:
        return torch.relu(input)

def clamp(input, min_, max_):
    # type: (Tensor, float, float) -> Tensor
    r"""clamp(input, min_, max_) -> Tensor
    Applies the clamp function element-wise.
    See :class:`~torch.nn.quantized.clamp` for more details.
    Args:
        input: quantized input
        min_: minimum value for clamping
        max_: maximum value for clamping
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.clamp' must be quantized!")
    return torch.clamp(input, min_, max_)

def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    r"""Upsamples the input to either the given :attr:`size` or the given
    :attr:`scale_factor`
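
Usage of the new functional wrapper would look roughly like this (a sketch; values are illustrative and the input must already be quantized):

import torch
from torch.nn.quantized import functional as qF

qx = torch.quantize_per_tensor(torch.randn(2, 2), scale=0.1, zero_point=64,
                               dtype=torch.quint8)
qy = qF.clamp(qx, -0.5, 0.5)   # delegates to torch.clamp on the quantized tensor
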
