diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 542adb9698176..614dc46158e8f 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -60,13 +60,6 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::QUInt2x4:
       TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
       break;
-    case ScalarType::Bits1x8:
-    case ScalarType::Bits2x4:
-    case ScalarType::Bits4x2:
-    case ScalarType::Bits8:
-    case ScalarType::Bits16:
-      TORCH_CHECK(false, "Bit types are not supported by dlpack");
-      break;
     case ScalarType::Undefined:
       TORCH_CHECK(false, "Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index 2fa3c9ceb4ea4..51de905def9c1 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -3,7 +3,6 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Deprecated.h>
 #include <c10/util/Half.h>
-#include <c10/util/bits.h>
 #include <c10/util/complex.h>
 #include <c10/util/qint32.h>
 #include <c10/util/qint8.h>
@@ -44,12 +43,7 @@ namespace c10 {
   _(c10::qint32, QInt32) /* 14 */                \
   _(at::BFloat16, BFloat16) /* 15 */             \
   _(c10::quint4x2, QUInt4x2) /* 16 */            \
-  _(c10::quint2x4, QUInt2x4) /* 17 */            \
-  _(c10::bits1x8, Bits1x8) /* 18 */              \
-  _(c10::bits2x4, Bits2x4) /* 19 */              \
-  _(c10::bits4x2, Bits4x2) /* 20 */              \
-  _(c10::bits8, Bits8) /* 21 */                  \
-  _(c10::bits16, Bits16) /* 22 */
+  _(c10::quint2x4, QUInt2x4) /* 17 */

 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -278,12 +272,6 @@ static inline bool isQIntType(ScalarType t) {
       t == ScalarType::QUInt2x4;
 }

-static inline bool isBitsType(ScalarType t) {
-  return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
-      t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
-      t == ScalarType::Bits16;
-}
-
 static inline ScalarType toQIntType(ScalarType t) {
   switch (t) {
     case ScalarType::Byte:
@@ -321,12 +309,6 @@ static inline bool isSignedType(ScalarType t) {
     return std::numeric_limits<ctype>::is_signed;

   switch (t) {
-    case ScalarType::Bits1x8:
-    case ScalarType::Bits2x4:
-    case ScalarType::Bits4x2:
-    case ScalarType::Bits8:
-    case ScalarType::Bits16:
-      TORCH_CHECK(false, "Bits types are undefined");
     case ScalarType::ComplexHalf:
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
@@ -441,38 +423,28 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         toString(b));
   }

-  if (isBitsType(a) && a == b) {
-    return a;
-  } else if (isBitsType(a) || isBitsType(b)) {
-    return ScalarType::Undefined;
-  }
-
   // this matrix has to be consistent with
   // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
   // are not sure about the correct value for type promotion.
   static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
       ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
-      // clang-format off
-      /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf  q4  q5*/
-      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, ud, ud},
-      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, ud, ud},
-      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, ud, ud},
-      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, ud, ud},
-      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, ud, ud},
-      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, ud, ud},
-      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, ud, ud},
-      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, ud, ud},
-      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, ud, ud},
-      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, ud, ud},
-      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, ud, ud},
-      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, ud, ud},
-      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, ud, ud},
-      /* q4 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* q5 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      // clang-format on
+      /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
+      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
+      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
+      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
+      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
+      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
+      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
+      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
+      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
+      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
+      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
+      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
+      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
+      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
diff --git a/c10/test/util/bits16_test.py b/c10/test/util/bits16_test.py
deleted file mode 100644
index 97a8220f16fc8..0000000000000
--- a/c10/test/util/bits16_test.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
-from torch.utils._mode_utils import no_dispatch
-from torch.utils._pytree import tree_map
-
-class TensorSubclassDemo(torch.Tensor):
-    def __new__(cls, elem):
-        assert elem.dtype == torch.bits16
-        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
-    def __init__(self, elem):
-        super().__init__()
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        def unwrap(t):
-            if isinstance(t, torch.Tensor):
-                with no_dispatch():
-                    return t.view(torch.int16)
-            return t
-
-        args = tree_map(unwrap, args)
-        kwargs = tree_map(unwrap, kwargs)
-        with no_dispatch():
-            out = func(*args, **kwargs)
-        return out.view(torch.bits16)
-
-    def __repr__(self) -> str:
-        with no_dispatch():
-            return f"TensorSubclassDemo{self.view(torch.int16)}"
-
-
-class TestBits16(TestCase):
-    def test(self):
-        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
-        _ = torch.empty(20, dtype=torch.bits16)
-
-        s = TensorSubclassDemo(t)
-        s = s + 1
-
-
-if __name__ == '__main__':
-    run_tests()
diff --git a/c10/test/util/bits_test.py b/c10/test/util/bits_test.py
deleted file mode 100644
index c87c8428b29a1..0000000000000
--- a/c10/test/util/bits_test.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
-from torch.utils._mode_utils import no_dispatch
-from torch.utils._pytree import tree_map
-
-class Int16Tensor(torch.Tensor):
-    def __new__(cls, elem):
-        assert elem.dtype == torch.bits16
-        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
-    def __init__(self, elem):
-        super().__init__()
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        def unwrap(t):
-            if isinstance(t, torch.Tensor):
-                with no_dispatch():
-                    return t.view(torch.int16)
-            return t
-        args = tree_map(unwrap, args)
-        kwargs = tree_map(unwrap, kwargs)
-
-        with no_dispatch():
-            out = func(*args, **kwargs)
-
-        def wrap(t):
-            if isinstance(t, torch.Tensor):
-                with no_dispatch():
-                    return t.view(torch.bits16)
-            return t
-        out = tree_map(wrap, out)
-        return out
-
-    def __repr__(self) -> str:
-        with no_dispatch():
-            t16 = self.view(torch.int16)
-            return f"TensorSubclassDemo{self.view(torch.int16)}"
-
-
-class TestBits(TestCase):
-    def test_types(self):
-        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
-        for bits_type in bits_types:
-            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
-            _ = torch.empty(20, dtype=bits_type)
-
-    def test_subclass(self):
-        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
-        s = Int16Tensor(t)
-        s = s + 1 - 1
-        self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))
-
-
-if __name__ == '__main__':
-    run_tests()
diff --git a/c10/util/bits.h b/c10/util/bits.h
deleted file mode 100644
index 89abf454791ef..0000000000000
--- a/c10/util/bits.h
+++ /dev/null
@@ -1,61 +0,0 @@
-#pragma once
-#include <cstdint>
-
-#include <c10/macros/Macros.h>
-
-namespace c10 {
-
-/**
- * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
- * boundary), without any semantics defined.
- */
-struct alignas(1) bits1x8 {
-  using underlying = uint8_t;
-  uint8_t val_;
-  bits1x8() = default;
-  C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
-};
-
-/**
- * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
- * boundary), without any semantics defined.
- */
-struct alignas(1) bits2x4 {
-  using underlying = uint8_t;
-  uint8_t val_;
-  bits2x4() = default;
-  C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
-};
-
-/**
- * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
- * boundary), without any semantics defined.
- */
-struct alignas(1) bits4x2 {
-  using underlying = uint8_t;
-  uint8_t val_;
-  bits4x2() = default;
-  C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
-};
-
-/**
- * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
- * semantics defined.
- */
-struct alignas(1) bits8 {
-  uint8_t val_;
-  bits8() = default;
-  C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {}
-};
-
-/**
- * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any
- * semantics defined.
- */
-struct alignas(2) bits16 {
-  uint16_t val_;
-  bits16() = default;
-  C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {}
-};
-
-} // namespace c10
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index 07ed3297d557d..3e0e3acf38c29 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -52,16 +52,6 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
     case at::ScalarType::QUInt4x2:
       return std::make_pair("quint4x2", "");
    case at::ScalarType::QUInt2x4:
      return std::make_pair("quint2x4", "");
-    case at::ScalarType::Bits1x8:
-      return std::make_pair("bits1x8", "");
-    case at::ScalarType::Bits2x4:
-      return std::make_pair("bits2x4", "");
-    case at::ScalarType::Bits4x2:
-      return std::make_pair("bits4x2", "");
-    case at::ScalarType::Bits8:
-      return std::make_pair("bits8", "");
-    case at::ScalarType::Bits16:
-      return std::make_pair("bits16", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
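
Not part of the patch: the deleted tests above are the main record of what the bits dtypes were used for, namely driving a __torch_dispatch__ subclass that reinterprets a tensor's raw storage via dtype views. Below is a minimal sketch of that same unwrap/compute/wrap pattern, rewritten against dtypes that still exist after this change (float16 storage dispatched as int16, both 2 bytes wide); the class name ViewAsInt16Tensor and the float16/int16 pairing are illustrative assumptions, not anything defined by this patch.

import torch
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

class ViewAsInt16Tensor(torch.Tensor):
    """Holds float16 storage but dispatches every op on its int16 bit pattern."""

    def __new__(cls, elem):
        assert elem.dtype == torch.float16
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Reinterpret every tensor argument as int16 before dispatching.
        def unwrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.int16)
            return t

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})
        with no_dispatch():
            out = func(*args, **kwargs)

        # Reinterpret results back to the storage dtype.
        def wrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.float16)
            return t

        return tree_map(wrap, out)

t = ViewAsInt16Tensor(torch.zeros(4, dtype=torch.float16))
s = t + 1  # the add runs on the int16 view of the same bytes
assert (s.view(torch.int16) == 1).all()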