forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add bits tensor types (pytorch#88594)
TODO (in later PRs) - [ ] the other bits8, 4x2, 2x4, 1x8 - [ ] bits printer function Pull Request resolved: pytorch#88594 Approved by: https://github.com/ezyang
- Loading branch information
1 parent
22e7514
commit f3b1315
Showing
6 changed files
with
223 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
from torch.utils._mode_utils import no_dispatch | ||
from torch.utils._pytree import tree_map | ||
|
||
class TensorSubclassDemo(torch.Tensor): | ||
def __new__(cls, elem): | ||
assert elem.dtype == torch.bits16 | ||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) | ||
|
||
def __init__(self, elem): | ||
super().__init__() | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | ||
def unwrap(t): | ||
if isinstance(t, torch.Tensor): | ||
with no_dispatch(): | ||
return t.view(torch.int16) | ||
return t | ||
|
||
args = tree_map(unwrap, args) | ||
kwargs = tree_map(unwrap, kwargs) | ||
with no_dispatch(): | ||
out = func(*args, **kwargs) | ||
return out.view(torch.bits16) | ||
|
||
def __repr__(self) -> str: | ||
with no_dispatch(): | ||
return f"TensorSubclassDemo{self.view(torch.int16)}" | ||
|
||
|
||
class TestBits16(TestCase): | ||
def test(self): | ||
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16) | ||
_ = torch.empty(20, dtype=torch.bits16) | ||
|
||
s = TensorSubclassDemo(t) | ||
s = s + 1 | ||
|
||
|
||
if __name__ == '__main__': | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
from torch.utils._mode_utils import no_dispatch | ||
from torch.utils._pytree import tree_map | ||
|
||
class Int16Tensor(torch.Tensor): | ||
def __new__(cls, elem): | ||
assert elem.dtype == torch.bits16 | ||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) | ||
|
||
def __init__(self, elem): | ||
super().__init__() | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | ||
def unwrap(t): | ||
if isinstance(t, torch.Tensor): | ||
with no_dispatch(): | ||
return t.view(torch.int16) | ||
return t | ||
args = tree_map(unwrap, args) | ||
kwargs = tree_map(unwrap, kwargs) | ||
|
||
with no_dispatch(): | ||
out = func(*args, **kwargs) | ||
|
||
def wrap(t): | ||
if isinstance(t, torch.Tensor): | ||
with no_dispatch(): | ||
return t.view(torch.bits16) | ||
return t | ||
out = tree_map(wrap, out) | ||
return out | ||
|
||
def __repr__(self) -> str: | ||
with no_dispatch(): | ||
t16 = self.view(torch.int16) | ||
return f"TensorSubclassDemo{self.view(torch.int16)}" | ||
|
||
|
||
class TestBits(TestCase): | ||
def test_types(self): | ||
bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16] | ||
for bits_type in bits_types: | ||
_ = torch.zeros(20, dtype=torch.int32).view(bits_type) | ||
_ = torch.empty(20, dtype=bits_type) | ||
|
||
def test_subclass(self): | ||
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16) | ||
s = Int16Tensor(t) | ||
s = s + 1 - 1 | ||
self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16))) | ||
|
||
|
||
if __name__ == '__main__': | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#pragma once | ||
#include <cstdint> | ||
|
||
#include <c10/macros/Macros.h> | ||
|
||
namespace c10 { | ||
|
||
/** | ||
* bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits1x8 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits1x8() = default; | ||
C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits2x4 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits2x4() = default; | ||
C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits4x2 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits4x2() = default; | ||
C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits8 is an uninterpreted dtype of a tensor with 8 bits, without any | ||
* semantics defined. | ||
*/ | ||
struct alignas(1) bits8 { | ||
uint8_t val_; | ||
bits8() = default; | ||
C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits16 is an uninterpreted dtype of a tensor with 16 bits, without any | ||
* semantics defined. | ||
*/ | ||
struct alignas(2) bits16 { | ||
uint16_t val_; | ||
bits16() = default; | ||
C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} | ||
}; | ||
|
||
} // namespace c10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters