Add bits tensor types (pytorch#88594)
TODO (in later PRs)
- [ ] the other bits8, 4x2, 2x4, 1x8
- [ ] bits printer function
Pull Request resolved: pytorch#88594
Approved by: https://github.com/ezyang
angelayi authored and pytorchmergebot committed Nov 28, 2022
1 parent 22e7514 commit f3b1315
Showing 6 changed files with 223 additions and 18 deletions.
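
The new dtypes are opaque bit containers: they carry no numeric semantics and no kernels, so tensors are created empty or by reinterpreting existing storage. A minimal usage sketch, based on the tests added in this commit; the last line (computing through an int16 view outside a subclass) is an assumption about direct use, not something the tests do verbatim:

import torch

# Reinterpret existing int16 storage as the new bits16 dtype, or allocate
# uninitialized bits storage directly.
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
u = torch.empty(20, dtype=torch.bits16)

# No operators are defined on bits tensors themselves; compute on a view
# with a real dtype and reinterpret the result back.
result = (t.view(torch.int16) + 1).view(torch.bits16)
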
7 changes: 7 additions & 0 deletions aten/src/ATen/DLConvertor.cpp
@@ -60,6 +60,13 @@ DLDataType getDLDataType(const Tensor& t) {
case ScalarType::QUInt2x4:
TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
break;
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Bits16:
TORCH_CHECK(false, "Bit types are not supported by dlpack");
break;
case ScalarType::Undefined:
TORCH_CHECK(false, "Undefined is not a valid ScalarType");
case ScalarType::NumOptions:
64 changes: 46 additions & 18 deletions c10/core/ScalarType.h
@@ -3,6 +3,7 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
@@ -43,7 +44,12 @@ namespace c10 {
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
@@ -272,6 +278,12 @@ static inline bool isQIntType(ScalarType t) {
t == ScalarType::QUInt2x4;
}

static inline bool isBitsType(ScalarType t) {
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
t == ScalarType::Bits16;
}

static inline ScalarType toQIntType(ScalarType t) {
switch (t) {
case ScalarType::Byte:
@@ -309,6 +321,12 @@ static inline bool isSignedType(ScalarType t) {
return std::numeric_limits<ctype>::is_signed;

switch (t) {
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Bits16:
TORCH_CHECK(false, "Bits types are undefined");
case ScalarType::ComplexHalf:
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
@@ -423,28 +441,38 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
toString(b));
}

if (isBitsType(a) && a == b) {
return a;
} else if (isBitsType(a) || isBitsType(b)) {
return ScalarType::Undefined;
}

// this matrix has to be consistent with
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS; undefined is used where we
// are not sure about the correct value for type promotion.
static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
// clang-format off
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf q4 q5*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, ud, ud},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, ud, ud},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, ud, ud},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, ud, ud},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, ud, ud},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, ud, ud},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, ud, ud},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, ud, ud},
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, ud, ud},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, ud, ud},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, ud, ud},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, ud, ud},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, ud, ud},
/* q4 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q5 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
// clang-format on
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
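
The bits handling added to promoteTypes above short-circuits before the lookup table: a bits type promotes only with itself, and any mix of a bits type with another dtype is undefined. The table itself also grows from 16x16 to 18x18 so that the quint4x2/quint2x4 columns (q4, q5) have entries. A small Python mirror of the new branch, for illustration only (promote_bits and the string names are hypothetical, not a PyTorch API):

BITS_TYPES = {"bits1x8", "bits2x4", "bits4x2", "bits8", "bits16"}

def promote_bits(a, b):
    # Mirrors the early return in c10::promoteTypes: identical bits types
    # promote to themselves; mixing bits with anything else is Undefined.
    if a in BITS_TYPES and a == b:
        return a
    if a in BITS_TYPES or b in BITS_TYPES:
        return "Undefined"
    return "consult the promotion matrix"  # non-bits pairs fall through to the table
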
43 changes: 43 additions & 0 deletions c10/test/util/bits16_test.py
@@ -0,0 +1,43 @@
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

class TensorSubclassDemo(torch.Tensor):
    def __new__(cls, elem):
        assert elem.dtype == torch.bits16
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    def __init__(self, elem):
        super().__init__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.int16)
            return t

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        with no_dispatch():
            out = func(*args, **kwargs)
        return out.view(torch.bits16)

    def __repr__(self) -> str:
        with no_dispatch():
            return f"TensorSubclassDemo{self.view(torch.int16)}"


class TestBits16(TestCase):
    def test(self):
        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
        _ = torch.empty(20, dtype=torch.bits16)

        s = TensorSubclassDemo(t)
        s = s + 1


if __name__ == '__main__':
    run_tests()
56 changes: 56 additions & 0 deletions c10/test/util/bits_test.py
@@ -0,0 +1,56 @@
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

class Int16Tensor(torch.Tensor):
    def __new__(cls, elem):
        assert elem.dtype == torch.bits16
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    def __init__(self, elem):
        super().__init__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.int16)
            return t
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        with no_dispatch():
            out = func(*args, **kwargs)

        def wrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.bits16)
            return t
        out = tree_map(wrap, out)
        return out

    def __repr__(self) -> str:
        with no_dispatch():
            t16 = self.view(torch.int16)
            return f"TensorSubclassDemo{self.view(torch.int16)}"


class TestBits(TestCase):
    def test_types(self):
        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
        for bits_type in bits_types:
            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
            _ = torch.empty(20, dtype=bits_type)

    def test_subclass(self):
        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
        s = Int16Tensor(t)
        s = s + 1 - 1
        self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))


if __name__ == '__main__':
    run_tests()
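
Int16Tensor above is the pattern this commit anticipates for giving meaning to otherwise uninterpreted bits storage: __torch_dispatch__ unwraps every tensor argument to int16, runs the real kernel, and views the outputs back as bits16. A short usage sketch along the lines of test_subclass (the arange values are arbitrary):

t = torch.arange(4, dtype=torch.int16).view(torch.bits16)
s = Int16Tensor(t)
s2 = s + 1  # computed on the int16 view, returned viewed as bits16
assert s2.dtype == torch.bits16
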
61 changes: 61 additions & 0 deletions c10/util/bits.h
@@ -0,0 +1,61 @@
#pragma once
#include <cstdint>

#include <c10/macros/Macros.h>

namespace c10 {

/**
* bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits1x8 {
using underlying = uint8_t;
uint8_t val_;
bits1x8() = default;
C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
};

/**
* bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits2x4 {
using underlying = uint8_t;
uint8_t val_;
bits2x4() = default;
C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
};

/**
* bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits4x2 {
using underlying = uint8_t;
uint8_t val_;
bits4x2() = default;
C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
};

/**
* bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
* semantics defined.
*/
struct alignas(1) bits8 {
uint8_t val_;
bits8() = default;
C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {}
};

/**
* bits16 is an uninterpreted dtype of a tensor with 16 bits, without any
* semantics defined.
*/
struct alignas(2) bits16 {
uint16_t val_;
bits16() = default;
C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {}
};

} // namespace c10
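
Each struct is a trivial wrapper around its packed storage: one byte for bits1x8, bits2x4, bits4x2, and bits8, and two bytes for bits16, which is what determines the element size of the corresponding torch dtypes. A quick check from Python (a sketch; it assumes element_size() reports the size of the underlying struct, as it does for the existing dtypes):

import torch

for dt in (torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16):
    print(dt, torch.empty(4, dtype=dt).element_size())
# expected: 1 byte each for the first four, 2 bytes for bits16
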
10 changes: 10 additions & 0 deletions torch/csrc/utils/tensor_dtypes.cpp
@@ -52,6 +52,16 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
return std::make_pair("quint4x2", "");
case at::ScalarType::QUInt2x4:
return std::make_pair("quint2x4", "");
case at::ScalarType::Bits1x8:
return std::make_pair("bits1x8", "");
case at::ScalarType::Bits2x4:
return std::make_pair("bits2x4", "");
case at::ScalarType::Bits4x2:
return std::make_pair("bits4x2", "");
case at::ScalarType::Bits8:
return std::make_pair("bits8", "");
case at::ScalarType::Bits16:
return std::make_pair("bits16", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}
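
These name pairs are what register the new dtypes as attributes on the torch module; the empty second string means there is no legacy alias, matching the quantized types just above. A minimal sanity check (a sketch):

import torch

assert all(hasattr(torch, name)
           for name in ("bits1x8", "bits2x4", "bits4x2", "bits8", "bits16"))
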
