Rename _Typed/_UntypedStorage to Typed/UntypedStorage and update docs (pytorch#82438)

### Description

Since the major changes to `_TypedStorage` and `_UntypedStorage` are now complete, they are renamed to the public `TypedStorage` and `UntypedStorage`.

`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.

Documentation for storages is improved as well.
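
A minimal migration sketch (not part of this commit's diff; it assumes a PyTorch build that includes this rename and does not rely on the old private names):

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])

s = t.storage()   # a torch.TypedStorage (previously torch._TypedStorage)
u = s.untyped()   # renamed from TypedStorage._untyped()

assert isinstance(s, torch.TypedStorage)     # previously torch._TypedStorage
assert isinstance(u, torch.UntypedStorage)   # previously torch._UntypedStorage
print(s.dtype, len(s), len(u))               # torch.float32 3 12  (12 raw bytes)
```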

### Issue
Fixes pytorch#82436

### Testing
N/A

Pull Request resolved: pytorch#82438
Approved by: https://github.com/ezyang
kurtamohler authored and pytorchmergebot committed Jul 30, 2022
1 parent 28304dd commit 14d0296
Showing 26 changed files with 207 additions and 199 deletions.
37 changes: 21 additions & 16 deletions docs/source/storage.rst
@@ -1,28 +1,33 @@
torch.Storage
===================================

A :class:`torch._TypedStorage` is a contiguous, one-dimensional array of
:class:`torch.Storage` is an alias for the storage class that corresponds with
the default data type (:func:`torch.get_default_dtype()`). For instance, if the
default data type is :attr:`torch.float`, :class:`torch.Storage` resolves to
:class:`torch.FloatStorage`.

The :class:`torch.<type>Storage` and :class:`torch.cuda.<type>Storage` classes,
like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc., are not
actually ever instantiated. Calling their constructors creates
a :class:`torch.TypedStorage` with the appropriate :class:`torch.dtype` and
:class:`torch.device`. :class:`torch.<type>Storage` classes have all of the
same class methods that :class:`torch.TypedStorage` has.

A :class:`torch.TypedStorage` is a contiguous, one-dimensional array of
elements of a particular :class:`torch.dtype`. It can be given any
:class:`torch.dtype`, and the internal data will be interpreted appropriately.
:class:`torch.TypedStorage` contains a :class:`torch.UntypedStorage` which
holds the data as an untyped array of bytes.

Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`,
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.

For backward compatibility, there are also :class:`torch.<type>Storage` classes
(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc). These
classes are not actually instantiated, and calling their constructors creates
a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`.
:class:`torch.<type>Storage` classes have all of the same class methods that
:class:`torch._TypedStorage` has.

Also for backward compatibility, :class:`torch.Storage` is an alias for the
storage class that corresponds with the default data type
(:func:`torch.get_default_dtype()`). For instance, if the default data type is
:attr:`torch.float`, :class:`torch.Storage` resolves to
:class:`torch.FloatStorage`.

.. autoclass:: torch.TypedStorage
:members:
:undoc-members:
:inherited-members:

.. autoclass:: torch._TypedStorage
.. autoclass:: torch.UntypedStorage
:members:
:undoc-members:
:inherited-members:
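
A short illustration of the relationships documented above (a sketch, not part of the diff; it assumes the default dtype is `torch.float32`):

```python
import torch

# Legacy <type>Storage constructors do not instantiate those classes;
# they build a TypedStorage with the matching dtype.
s = torch.FloatStorage(4)
assert isinstance(s, torch.TypedStorage)
print(s.dtype)    # torch.float32

# Every TypedStorage wraps an UntypedStorage that holds the raw bytes.
u = s.untyped()
assert isinstance(u, torch.UntypedStorage)
print(len(u))     # 16 (4 float32 elements * 4 bytes each)
```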
2 changes: 1 addition & 1 deletion test/allowlist_for_publicAPI.json
@@ -1852,7 +1852,7 @@
"QUInt4x2Storage",
"QUInt8Storage",
"Storage",
"_TypedStorage",
"TypedStorage",
"_adaptive_avg_pool2d",
"_adaptive_avg_pool3d",
"_add_batch_dim",
4 changes: 2 additions & 2 deletions test/test_cuda.py
@@ -569,8 +569,8 @@ def test_serialization_array_with_storage(self):
self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

12 changes: 6 additions & 6 deletions test/test_serialization.py
@@ -97,7 +97,7 @@ def _test_serialization_assert(self, b, c):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertEqual(c[4].dtype, torch.float)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@@ -370,7 +370,7 @@ def test_serialization_backwards_compat(self):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertEqual(c[4].dtype, torch.float32)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@@ -621,8 +621,8 @@ def save_load_check(a, b):
a = torch.tensor([], dtype=dtype, device=device)

for other_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
s = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
s = torch.TypedStorage(
wrap_storage=a.storage().untyped(),
dtype=other_dtype)
save_load_check(a, s)
save_load_check(a.storage(), s)
@@ -653,8 +653,8 @@ def test_save_different_dtype_error(self):
torch.save([a.storage(), a.imag.storage()], f)

a = torch.randn(10, device=device)
s_bytes = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
s_bytes = torch.TypedStorage(
wrap_storage=a.storage().untyped(),
dtype=torch.uint8)

with self.assertRaisesRegex(RuntimeError, error_msg):
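
The `wrap_storage=` form exercised in these tests reinterprets an existing `UntypedStorage`'s bytes under another dtype. Roughly (a sketch, not part of the diff):

```python
import torch

a = torch.tensor([1.0, 2.0], dtype=torch.float32)
raw = a.storage().untyped()    # UntypedStorage holding the tensor's 8 raw bytes

# Same bytes, now viewed as uint8 elements instead of float32.
as_bytes = torch.TypedStorage(wrap_storage=raw, dtype=torch.uint8)
print(len(as_bytes), as_bytes.dtype)   # 8 torch.uint8
```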
64 changes: 32 additions & 32 deletions test/test_torch.py
@@ -157,7 +157,7 @@ def rand_byte():
for i in range(10):
bytes_list = [rand_byte() for _ in range(element_size)]
scalar = bytes_to_scalar(bytes_list, dtype, device)
self.assertEqual(scalar.storage()._untyped().tolist(), bytes_list)
self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)

@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
torch.bool, torch.float32, torch.complex64, torch.float64,
@@ -175,7 +175,7 @@ def test_storage(self, device, dtype):
v_s[el_num],
v[dim0][dim1])

v_s_byte = v.storage()._untyped()
v_s_byte = v.storage().untyped()
el_size = v.element_size()

for el_num in range(v.numel()):
@@ -238,7 +238,7 @@ def test_tensor_from_storage(self, device, dtype):
a_s = a.storage()
b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
self.assertEqual(a, b)
c = torch.tensor(a_s._untyped(), device=device, dtype=dtype).reshape(a.size())
c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size())
self.assertEqual(a, c)

for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
@@ -255,7 +255,7 @@ def test_set_storage(self, device, dtype):
a_s = a.storage()
b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
self.assertEqual(a, b)
c = torch.tensor([], device=device, dtype=dtype).set_(a_s._untyped()).reshape(a.size())
c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size())
self.assertEqual(a, c)

for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
@@ -267,11 +267,11 @@ def test_set_storage(self, device, dtype):

def _check_storage_meta(self, s, s_check):
self.assertTrue(
isinstance(s, (torch._UntypedStorage, torch._TypedStorage)) and
isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and
isinstance(s_check, type(s)),
(
's and s_check must both be one of _UntypedStorage or '
'_TypedStorage, but got'
's and s_check must both be one of UntypedStorage or '
'TypedStorage, but got'
f' {type(s).__name__} and {type(s_check).__name__}'))

self.assertEqual(s.device.type, 'meta')
@@ -282,9 +282,9 @@ def _check_storage_meta(self, s, s_check):
with self.assertRaisesRegex(NotImplementedError, r'Not available'):
s[0]

if isinstance(s, torch._TypedStorage):
if isinstance(s, torch.TypedStorage):
self.assertEqual(s.dtype, s_check.dtype)
self._check_storage_meta(s._untyped(), s_check._untyped())
self._check_storage_meta(s.untyped(), s_check.untyped())

@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@@ -296,8 +296,8 @@ def test_typed_storage_meta(self, device, dtype):
[[1, 2, 3, 4, 5, 6]],
]
for args in args_list:
s_check = torch._TypedStorage(*args, dtype=dtype, device=device)
s = torch._TypedStorage(*args, dtype=dtype, device='meta')
s_check = torch.TypedStorage(*args, dtype=dtype, device=device)
s = torch.TypedStorage(*args, dtype=dtype, device='meta')
self._check_storage_meta(s, s_check)

@onlyNativeDeviceTypes
@@ -309,8 +309,8 @@ def test_untyped_storage_meta(self, device):
[[1, 2, 3, 4, 5, 6]],
]
for args in args_list:
s_check = torch._UntypedStorage(*args, device=device)
s = torch._UntypedStorage(*args, device='meta')
s_check = torch.UntypedStorage(*args, device=device)
s = torch.UntypedStorage(*args, device='meta')
self._check_storage_meta(s, s_check)

@onlyNativeDeviceTypes
@@ -326,7 +326,7 @@ def test_storage_meta_from_tensor(self, device, dtype):
@onlyCPU
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_storage_meta_errors(self, device, dtype):
s0 = torch._TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)

with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
s0.cpu()
@@ -361,7 +361,7 @@ def test_storage_meta_errors(self, device, dtype):
s0._write_file(f, True, True, s0.element_size())

for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
s1 = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)

with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
s1.copy_(s0)
@@ -6444,7 +6444,7 @@ def test_storage_error(self):
torch.storage._LegacyStorage()

for storage_class in torch._storage_classes:
if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
if storage_class in [torch.UntypedStorage, torch.TypedStorage]:
continue

device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
@@ -6475,9 +6475,9 @@ def test_storage_error(self):
s = storage_class()

with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
storage_class(0, wrap_storage=s._untyped())
storage_class(0, wrap_storage=s.untyped())

with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"):
storage_class(wrap_storage=s)

if torch.cuda.is_available():
@@ -6493,40 +6493,40 @@ def test_storage_error(self):
s_other_device = s.cuda()

with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
storage_class(wrap_storage=s_other_device._untyped())
storage_class(wrap_storage=s_other_device.untyped())

# _TypedStorage constructor errors
# TypedStorage constructor errors
with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype)

with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
torch._TypedStorage(wrap_storage=s._untyped())
torch.TypedStorage(wrap_storage=s.untyped())

with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
torch.TypedStorage(wrap_storage=s.untyped(), dtype=0)

with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device)

with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
torch._TypedStorage(wrap_storage=s, dtype=dtype)
with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"):
torch.TypedStorage(wrap_storage=s, dtype=dtype)

with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
torch._TypedStorage(dtype=dtype, device='xla')
torch.TypedStorage(dtype=dtype, device='xla')

if torch.cuda.is_available():
if storage_class in quantized_storages:
with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
torch._TypedStorage(dtype=dtype, device='cuda')
torch.TypedStorage(dtype=dtype, device='cuda')

with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device)

with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
torch._TypedStorage(0, 0, dtype=dtype, device=device)
torch.TypedStorage(0, 0, dtype=dtype, device=device)

if isinstance(s, torch._TypedStorage):
s_other = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
if isinstance(s, torch.TypedStorage):
s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)

with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
s.fill_(s_other)
4 changes: 2 additions & 2 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -1146,7 +1146,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an _UntypedStorage, but got type ", storage_scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor {
pybind11::gil_scoped_release no_gil;
@@ -1162,7 +1162,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an _UntypedStorage, but got type ", storage_scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self,
Storage source,
4 changes: 2 additions & 2 deletions tools/pyi/gen_pyi.py
@@ -696,8 +696,8 @@ def gen_pyi(
"def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
],
"set_": [
"def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
"def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...",
],
"split": [
"def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -13,7 +13,7 @@ from typing_extensions import Literal
from torch._six import inf

from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt
from torch.storage import _TypedStorage
from torch.storage import TypedStorage

import builtins

10 changes: 5 additions & 5 deletions torch/__init__.py
@@ -40,7 +40,7 @@
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
'_TypedStorage',
'TypedStorage', 'UntypedStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
'lobpcg', 'use_deterministic_algorithms',
@@ -656,10 +656,10 @@ def is_warn_always_enabled():
################################################################################

from ._tensor import Tensor
from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage._TypedStorage directly.
# dtype, use torch.storage.TypedStorage directly.

class ByteStorage(_LegacyStorage):
@classproperty
@@ -747,11 +747,11 @@ def dtype(self):
return torch.quint2x4

_storage_classes = {
_UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
_TypedStorage
TypedStorage
}

# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
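
Per the NOTE above, a storage for a new dtype is obtained by constructing `TypedStorage` with that dtype directly instead of adding another `<type>Storage` subclass. An illustrative sketch (using an existing dtype for the example):

```python
import torch

# Construct a storage for a given dtype without a dedicated <type>Storage class.
s = torch.TypedStorage(4, dtype=torch.bfloat16, device='cpu')
print(s.dtype, len(s))   # torch.bfloat16 4
```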
10 changes: 5 additions & 5 deletions torch/_deploy.py
@@ -19,8 +19,8 @@ def _save_storages(importer, obj):
importers = sys_importer

def persistent_id(obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
@@ -66,11 +66,11 @@ def persistent_load(saved_id):

if typename == "storage":
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
# stop wrapping with TypedStorage
storage = serialized_storages[data[0]]
dtype = serialized_dtypes[data[0]]
return torch.storage._TypedStorage(
wrap_storage=storage._untyped(), dtype=dtype
return torch.storage.TypedStorage(
wrap_storage=storage.untyped(), dtype=dtype
)

if typename == "reduce_deploy":
2 changes: 1 addition & 1 deletion torch/_prims/__init__.py
@@ -11,7 +11,7 @@

import torch._prims_common as utils
import torch.library
from torch import _TypedStorage, Tensor
from torch import Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims_common import (
check,