Skip to content

Commit df55871

Browse files
authored
Update docstrings (#232)
1 parent 2a41f00 commit df55871

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ class ReplayBuffer:
114114
samples.
115115
prefetch (int, optional): number of next batches to be prefetched
116116
using multithreading.
117+
storage (Storage, optional): the storage to be used. If none is provided,
118+
a ListStorage will be instantiated.
117119
"""
118120

119121
def __init__(
@@ -300,6 +302,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
300302
samples.
301303
prefetch (int, optional): number of next batches to be prefetched
302304
using multithreading.
305+
storage (Storage, optional): the storage to be used. If none is provided,
306+
a ListStorage will be instantiated.
303307
"""
304308

305309
def __init__(
@@ -583,6 +587,8 @@ class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer):
583587
the rb samples. Default is `False`.
584588
prefetch (int, optional): number of next batches to be prefetched
585589
using multithreading.
590+
storage (Storage, optional): the storage to be used. If none is provided,
591+
a ListStorage will be instantiated.
586592
"""
587593

588594
def __init__(

torchrl/data/tensordict/memmap.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,6 @@
2121
torch_to_numpy_dtype_dict,
2222
)
2323

24-
# try:
25-
# from torch.utils._python_dispatch import enable_torch_dispatch_mode
26-
# from torch._subclasses.fake_tensor import (
27-
# FakeTensor,
28-
# FakeTensorMode,
29-
# FakeTensorConverter,
30-
# DynamicOutputShapeException,
31-
# )
32-
# _has_fake = True
33-
# except:
34-
_has_fake = False
35-
3624
MEMMAP_HANDLED_FN = {}
3725

3826
__all__ = ["MemmapTensor", "set_transfer_ownership"]
@@ -74,15 +62,23 @@ class MemmapTensor(object):
7462
Supports (almost) all tensor operations.
7563
7664
Args:
77-
elem (torch.Tensor or MemmapTensor): TODO // Tensor to be stored on physical
78-
storage. If MemmapTensor, a new MemmapTensor is created and the
79-
same data is stored in it.
80-
transfer_ownership: bool: affects the ownership after serialization:
65+
*tensor_or_size (torch.Tensor, MemmapTensor, torch.Size or sequence of integers):
66+
If a size is provided (with a sequence of integers, a torch.Size object
67+
or a list/tuple of integers) it indicates the size of the MemmapTensor created.
68+
If a tensor is provided, its content will be stored on physical storage.
69+
If MemmapTensor, a new MemmapTensor is created and the same data is stored in it.
70+
device (torch.device or equivalent, optional): device where the loaded
71+
tensor will be sent. This should not be used with MemmapTensors
72+
created from torch.Tensor objects. Default is "cpu".
73+
dtype (torch.dtype, optional): dtype of the loaded tensor.
74+
This should not be used with MemmapTensors created from torch.Tensor
75+
objects. Default is `torch.get_default_dtype()`.
76+
transfer_ownership (bool, optional): affects the ownership after serialization:
8177
if True, the current process loses ownership immediately after
8278
serialization. If False, the current process keeps the ownership
8379
of the temporary file.
8480
Default: False.
85-
prefix: TODO
81+
prefix (str or path, optional): prefix of the file location.
8682
8783
Examples:
8884
>>> x = torch.ones(3,4)
@@ -153,9 +149,6 @@ def __init__(
153149
device = device if device is not None else torch.device("cpu")
154150
dtype = dtype if dtype is not None else torch.get_default_dtype()
155151
self._init_shape(shape, device, dtype, transfer_ownership)
156-
if _has_fake:
157-
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
158-
self._fake = torch.zeros(self.shape, device=self.device)
159152

160153
def _init_shape(
161154
self,
@@ -272,10 +265,7 @@ def _load_item(
272265
and len(idx) == 1
273266
and not (isinstance(idx, torch.Tensor) and idx.dtype is torch.bool)
274267
): # and isinstance(idx, torch.Tensor) and len(idx) == 1:
275-
if _has_fake:
276-
size = self._fake[idx].shape
277-
else:
278-
size = _getitem_batch_size(self.shape, idx)
268+
size = _getitem_batch_size(self.shape, idx)
279269
out = out.view(size)
280270
return out
281271

0 commit comments

Comments
 (0)