forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add various uninterpreted bit tensor data types (pytorch#94992)
Summary: This PR adds a set of uninterpreted data types to PyTorch which can be used to implement experimental functionality out of core (think fp8, int4, int16 quant, etc). Note: this is a copy-paste of pytorch#89990 with a bug fix for clang9; it was easier just to put up another PR since I'm not sure how commandeering works with Meta-only changes. @bypass-github-export-checks Test Plan: ``` python test/test_quantization.py -k TestBits ``` Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#94992 Approved by: https://github.com/angelayi
- Loading branch information
1 parent
e44737e
commit 9dbfca7
Showing
6 changed files
with
174 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#pragma once | ||
#include <cstdint> | ||
|
||
#include <c10/macros/Macros.h> | ||
|
||
namespace c10 { | ||
|
||
/** | ||
* bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits1x8 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits1x8() = default; | ||
C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits2x4 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits2x4() = default; | ||
C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte | ||
* boundary), without any semantics defined. | ||
*/ | ||
struct alignas(1) bits4x2 { | ||
using underlying = uint8_t; | ||
uint8_t val_; | ||
bits4x2() = default; | ||
C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits8 is an uninterpreted dtype of a tensor with 8 bits, without any | ||
* semantics defined. | ||
*/ | ||
struct alignas(1) bits8 { | ||
uint8_t val_; | ||
bits8() = default; | ||
C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} | ||
}; | ||
|
||
/** | ||
* bits16 is an uninterpreted dtype of a tensor with 16 bits, without any | ||
* semantics defined. | ||
*/ | ||
struct alignas(2) bits16 { | ||
uint16_t val_; | ||
bits16() = default; | ||
C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} | ||
}; | ||
|
||
} // namespace c10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Owner(s): ["oncall: quantization"] | ||
|
||
import torch | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
from torch.utils._mode_utils import no_dispatch | ||
from torch.utils._pytree import tree_map | ||
|
||
class Int16Tensor(torch.Tensor): | ||
def __new__(cls, elem): | ||
assert elem.dtype == torch.bits16 | ||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) | ||
|
||
def __init__(self, elem): | ||
super().__init__() | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | ||
def unwrap(t): | ||
if isinstance(t, torch.Tensor): | ||
with no_dispatch(): | ||
return t.view(torch.int16) | ||
return t | ||
args = tree_map(unwrap, args) | ||
kwargs = tree_map(unwrap, kwargs) | ||
|
||
with no_dispatch(): | ||
out = func(*args, **kwargs) | ||
|
||
def wrap(t): | ||
if isinstance(t, torch.Tensor): | ||
with no_dispatch(): | ||
return t.view(torch.bits16) | ||
return t | ||
out = tree_map(wrap, out) | ||
return out | ||
|
||
def __repr__(self) -> str: | ||
with no_dispatch(): | ||
t16 = self.view(torch.int16) | ||
return f"TensorSubclassDemo{self.view(torch.int16)}" | ||
|
||
|
||
class TestBits(TestCase):
    """Smoke tests for the uninterpreted ``bits*`` dtypes."""

    def test_types(self):
        # Each bits dtype must be usable both as a view target and with empty().
        for dtype in (
            torch.bits1x8,
            torch.bits2x4,
            torch.bits4x2,
            torch.bits8,
            torch.bits16,
        ):
            _ = torch.zeros(20, dtype=torch.int32).view(dtype)
            _ = torch.empty(20, dtype=dtype)

    def test_subclass(self):
        # Adding and then subtracting 1 through the subclass must give zeros.
        raw = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
        wrapped = Int16Tensor(raw)
        wrapped = wrapped + 1 - 1
        expected = torch.zeros(20, dtype=torch.bits16)
        self.assertTrue(torch.allclose(wrapped, expected))
|
||
|
||
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters