diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index d795d3db44a17..928b206526bf5 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -60,6 +60,13 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::QUInt2x4: TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack"); break; + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: + TORCH_CHECK(false, "Bit types are not supported by dlpack"); + break; case ScalarType::Undefined: TORCH_CHECK(false, "Undefined is not a valid ScalarType"); case ScalarType::NumOptions: diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 5fa2f4cd6e457..31aac7b2f7ce1 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,12 @@ namespace c10 { _(c10::qint32, QInt32) /* 14 */ \ _(at::BFloat16, BFloat16) /* 15 */ \ _(c10::quint4x2, QUInt4x2) /* 16 */ \ - _(c10::quint2x4, QUInt2x4) /* 17 */ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). 
But beware: convert() @@ -270,6 +276,12 @@ static inline bool isQIntType(ScalarType t) {
       t == ScalarType::QUInt2x4;
 }
 
+static inline bool isBitsType(ScalarType t) {
+  return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
+      t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
+      t == ScalarType::Bits16;
+}
+
 static inline ScalarType toQIntType(ScalarType t) {
   switch (t) {
     case ScalarType::Byte:
@@ -307,6 +319,12 @@ static inline bool isSignedType(ScalarType t) {
     return std::numeric_limits<ctype>::is_signed;
 
   switch (t) {
+    case ScalarType::Bits1x8:
+    case ScalarType::Bits2x4:
+    case ScalarType::Bits4x2:
+    case ScalarType::Bits8:
+    case ScalarType::Bits16:
+      TORCH_CHECK(false, "Bits types are undefined");
     case ScalarType::ComplexHalf:
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
@@ -421,11 +439,24 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         toString(b));
   }
 
+  if (isBitsType(a) && a == b) {
+    return a;
+  } else if (isBitsType(a) || isBitsType(b)) {
+    return ScalarType::Undefined;
+  }
+
+  // Ignore the 5 bits types, since they are handled by the if statement
+  // above and do not participate in type promotion. The `5` value has to
+  // be consistent with the number of the unique `c10::bits*` types that
+  // exist.
+  const int NUM_PROMOTE_TYPES = static_cast<int>(ScalarType::NumOptions) - 5;
+
   // this matrix has to be consistent with
   // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
   // are not sure about the correct value for type promotion.
-  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
-      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
+  // clang-format off
+  static constexpr ScalarType _promoteTypesLookup[
+      NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
       /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
       /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
       /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
@@ -444,6 +475,7 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
       /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
       /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
   };
+  // clang-format on
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
 
diff --git a/c10/util/bits.h b/c10/util/bits.h
new file mode 100644
index 0000000000000..89abf454791ef
--- /dev/null
+++ b/c10/util/bits.h
@@ -0,0 +1,61 @@
+#pragma once
+#include <cstdint>
+
+#include <c10/macros/Macros.h>
+
+namespace c10 {
+
+/**
+ * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
+ * boundary), without any semantics defined.
+ */
+struct alignas(1) bits1x8 {
+  using underlying = uint8_t;
+  uint8_t val_;
+  bits1x8() = default;
+  C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
+};
+
+/**
+ * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
+ * boundary), without any semantics defined.
+ */
+struct alignas(1) bits2x4 {
+  using underlying = uint8_t;
+  uint8_t val_;
+  bits2x4() = default;
+  C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
+};
+
+/**
+ * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
+ * boundary), without any semantics defined.
+ */
+struct alignas(1) bits4x2 {
+  using underlying = uint8_t;
+  uint8_t val_;
+  bits4x2() = default;
+  C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
+};
+
+/**
+ * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
+ * semantics defined.
+ */ +struct alignas(1) bits8 { + uint8_t val_; + bits8() = default; + C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} +}; + +/** + * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any + * semantics defined. + */ +struct alignas(2) bits16 { + uint16_t val_; + bits16() = default; + C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} +}; + +} // namespace c10 diff --git a/test/quantization/core/experimental/test_bits.py b/test/quantization/core/experimental/test_bits.py new file mode 100644 index 0000000000000..895ad61009ec7 --- /dev/null +++ b/test/quantization/core/experimental/test_bits.py @@ -0,0 +1,58 @@ +# Owner(s): ["oncall: quantization"] + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils._mode_utils import no_dispatch +from torch.utils._pytree import tree_map + +class Int16Tensor(torch.Tensor): + def __new__(cls, elem): + assert elem.dtype == torch.bits16 + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + def __init__(self, elem): + super().__init__() + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(t): + if isinstance(t, torch.Tensor): + with no_dispatch(): + return t.view(torch.int16) + return t + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + with no_dispatch(): + out = func(*args, **kwargs) + + def wrap(t): + if isinstance(t, torch.Tensor): + with no_dispatch(): + return t.view(torch.bits16) + return t + out = tree_map(wrap, out) + return out + + def __repr__(self) -> str: + with no_dispatch(): + t16 = self.view(torch.int16) + return f"TensorSubclassDemo{self.view(torch.int16)}" + + +class TestBits(TestCase): + def test_types(self): + bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16] + for bits_type in bits_types: + _ = torch.zeros(20, dtype=torch.int32).view(bits_type) + _ = torch.empty(20, dtype=bits_type) + + def test_subclass(self): + 
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
+        s = Int16Tensor(t)
+        s = s + 1 - 1
+        self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 842009aeb55e2..48fe750bb3282 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -134,5 +134,8 @@ except ImportError:
     pass
 
+# Experimental functionality
+from quantization.core.experimental.test_bits import TestBits # noqa: F401
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index fd9a6b26a4b2d..84d7566a8c339 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -52,6 +52,16 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
       return std::make_pair("quint4x2", "");
     case at::ScalarType::QUInt2x4:
       return std::make_pair("quint2x4", "");
+    case at::ScalarType::Bits1x8:
+      return std::make_pair("bits1x8", "");
+    case at::ScalarType::Bits2x4:
+      return std::make_pair("bits2x4", "");
+    case at::ScalarType::Bits4x2:
+      return std::make_pair("bits4x2", "");
+    case at::ScalarType::Bits8:
+      return std::make_pair("bits8", "");
+    case at::ScalarType::Bits16:
+      return std::make_pair("bits16", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }