Revert "Add bits tensor types (pytorch#88594)"
This reverts commit f3b1315.

Reverted pytorch#88594 on behalf of https://github.com/jeanschmidt due to breaking internal builds
pytorchmergebot committed Nov 30, 2022
1 parent 296e1ba commit 4cc5be3
Showing 6 changed files with 18 additions and 223 deletions.
7 changes: 0 additions & 7 deletions aten/src/ATen/DLConvertor.cpp
@@ -60,13 +60,6 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::QUInt2x4:
       TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
       break;
-    case ScalarType::Bits1x8:
-    case ScalarType::Bits2x4:
-    case ScalarType::Bits4x2:
-    case ScalarType::Bits8:
-    case ScalarType::Bits16:
-      TORCH_CHECK(false, "Bit types are not supported by dlpack");
-      break;
     case ScalarType::Undefined:
       TORCH_CHECK(false, "Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
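
For context, DLPack describes an element type as a (type code, bit width, lane count) triple, so only dtypes with a plain fixed-width encoding can be exported; bit-packed dtypes such as Bits1x8 have no such triple, which is why the branch removed here rejected them. The sketch below is illustrative only (it is not part of this commit, the helper name is made up, and it assumes the standalone dlpack/dlpack.h header):

#include <dlpack/dlpack.h>

// Build the DLPack description of a 32-bit float element:
// code selects the family, bits the element width, lanes the vector width.
DLDataType exampleFloat32() {
  DLDataType dtype;
  dtype.code = kDLFloat; // float family (kDLInt / kDLUInt for integer dtypes)
  dtype.bits = 32;       // element width in bits
  dtype.lanes = 1;       // scalar elements, no vectorization
  return dtype;
}
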
64 changes: 18 additions & 46 deletions c10/core/ScalarType.h
@@ -3,7 +3,6 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Half.h>
-#include <c10/util/bits.h>
 #include <c10/util/complex.h>
 #include <c10/util/qint32.h>
 #include <c10/util/qint8.h>
@@ -44,12 +43,7 @@ namespace c10 {
   _(c10::qint32, QInt32) /* 14 */ \
   _(at::BFloat16, BFloat16) /* 15 */ \
   _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */ \
-  _(c10::bits1x8, Bits1x8) /* 18 */ \
-  _(c10::bits2x4, Bits2x4) /* 19 */ \
-  _(c10::bits4x2, Bits4x2) /* 20 */ \
-  _(c10::bits8, Bits8) /* 21 */ \
-  _(c10::bits16, Bits16) /* 22 */
+  _(c10::quint2x4, QUInt2x4) /* 17 */

 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -278,12 +272,6 @@ static inline bool isQIntType(ScalarType t) {
       t == ScalarType::QUInt2x4;
 }

-static inline bool isBitsType(ScalarType t) {
-  return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
-      t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
-      t == ScalarType::Bits16;
-}
-
 static inline ScalarType toQIntType(ScalarType t) {
   switch (t) {
     case ScalarType::Byte:
@@ -321,12 +309,6 @@ static inline bool isSignedType(ScalarType t) {
     return std::numeric_limits<ctype>::is_signed;

   switch (t) {
-    case ScalarType::Bits1x8:
-    case ScalarType::Bits2x4:
-    case ScalarType::Bits4x2:
-    case ScalarType::Bits8:
-    case ScalarType::Bits16:
-      TORCH_CHECK(false, "Bits types are undefined");
     case ScalarType::ComplexHalf:
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
@@ -441,38 +423,28 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         toString(b));
   }

-  if (isBitsType(a) && a == b) {
-    return a;
-  } else if (isBitsType(a) || isBitsType(b)) {
-    return ScalarType::Undefined;
-  }
-
   // this matrix has to be consistent with
   // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
   // are not sure about the correct value for type promotion.
   static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
       ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
       // clang-format off
-        /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf  q4  q5*/
-        /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, ud, ud},
-        /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, ud, ud},
-        /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, ud, ud},
-        /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, ud, ud},
-        /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, ud, ud},
-        /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, ud, ud},
-        /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, ud, ud},
-        /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, ud, ud},
-        /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, ud, ud},
-        /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, ud, ud},
-        /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, ud, ud},
-        /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, ud, ud},
-        /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-        /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-        /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-        /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, ud, ud},
-        /* q4 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-        /* q5 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+        /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
+        /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
+        /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
+        /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
+        /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
+        /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
+        /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
+        /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
+        /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
+        /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
+        /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
+        /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
+        /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
+        /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+        /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+        /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+        /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
       // clang-format on
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
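
The restored matrix is read by indexing the row with one dtype and the column with the other, so promotion stays a single constexpr table lookup. A minimal sketch of exercising it through the helper defined in this header (assumes a PyTorch/c10 build; the example itself is not part of this commit):

#include <c10/core/ScalarType.h>
#include <iostream>

int main() {
  using c10::ScalarType;
  // Row u1, column i1 -> i2: uint8 and int8 promote to int16 (Short).
  std::cout << c10::toString(c10::promoteTypes(ScalarType::Byte, ScalarType::Char)) << "\n";
  // Row f2, column bf -> f4: Half and BFloat16 promote to Float.
  std::cout << c10::toString(c10::promoteTypes(ScalarType::Half, ScalarType::BFloat16)) << "\n";
  return 0;
}
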
43 changes: 0 additions & 43 deletions c10/test/util/bits16_test.py

This file was deleted.

56 changes: 0 additions & 56 deletions c10/test/util/bits_test.py

This file was deleted.

61 changes: 0 additions & 61 deletions c10/util/bits.h

This file was deleted.
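
For readers wondering what the deleted c10/util/bits.h provided: it declared small carrier structs for the bit-packed dtypes, giving them storage but no arithmetic semantics. The snippet below is a rough sketch of that shape, not the deleted file's exact contents; the member name, constructor, and alignment are assumptions.

#include <cstdint>

namespace c10 {

// Sketch of a bit-carrier dtype: one byte of storage that packs eight 1-bit
// values; the type exists only so tensors can move the packed data around.
struct alignas(1) bits1x8 {
  using underlying = uint8_t;
  uint8_t val_;
  bits1x8() = default;
  explicit bits1x8(uint8_t val) : val_(val) {}
};

} // namespace c10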

10 changes: 0 additions & 10 deletions torch/csrc/utils/tensor_dtypes.cpp
@@ -52,16 +52,6 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
       return std::make_pair("quint4x2", "");
     case at::ScalarType::QUInt2x4:
       return std::make_pair("quint2x4", "");
-    case at::ScalarType::Bits1x8:
-      return std::make_pair("bits1x8", "");
-    case at::ScalarType::Bits2x4:
-      return std::make_pair("bits2x4", "");
-    case at::ScalarType::Bits4x2:
-      return std::make_pair("bits4x2", "");
-    case at::ScalarType::Bits8:
-      return std::make_pair("bits8", "");
-    case at::ScalarType::Bits16:
-      return std::make_pair("bits16", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
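
With the Bits cases gone, those scalar types fall through to the default branch and raise "Unimplemented scalar type". For the remaining dtypes, the returned pair is (canonical name, legacy alias), with an empty second element meaning no alias is registered. A small sketch of reading it (assumption: the function is reachable as torch::utils::getDtypeNames via the header that declares it; the printing is only illustrative):

#include <torch/csrc/utils/tensor_dtypes.h>
#include <iostream>

int main() {
  // "quint4x2" is the canonical name -> exposed to Python as torch.quint4x2.
  auto names = torch::utils::getDtypeNames(at::ScalarType::QUInt4x2);
  std::cout << names.first << "\n";
  // An empty legacy name means no deprecated alias exists for this dtype.
  std::cout << (names.second.empty() ? "<no legacy alias>" : names.second) << "\n";
  return 0;
}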
