
Commit 44eae43

Feature: Create MemmapTensors with shape, device and dtype (#219)
1 parent: 01e00b5

File tree

2 files changed: +65 / -5 lines

torchrl/data/tensordict/memmap.py

Lines changed: 64 additions & 5 deletions
@@ -7,6 +7,7 @@
 
 import functools
 import tempfile
+from math import prod
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -97,7 +98,65 @@ class MemmapTensor(object):
     def __init__(
         self,
         elem: Union[torch.Tensor, MemmapTensor],
+        *size: int,
+        device: DEVICE_TYPING = None,
+        dtype: torch.dtype = None,
         transfer_ownership: bool = False,
+    ):
+        self.idx = None
+        self._memmap_array = None
+        self.file = tempfile.NamedTemporaryFile()
+        self.filename = self.file.name
+
+        if isinstance(elem, (torch.Tensor, MemmapTensor, np.ndarray)):
+            if device is not None:
+                raise TypeError(
+                    "device cannot be passed when creating a MemmapTensor from a tensor"
+                )
+            if dtype is not None:
+                raise TypeError(
+                    "dtype cannot be passed when creating a MemmapTensor from a tensor"
+                )
+            return self._init_tensor(elem, transfer_ownership)
+        else:
+            if not isinstance(elem, int) and size:
+                raise TypeError(
+                    "Valid init methods for MemmapTensor are: "
+                    "\n- MemmapTensor(tensor, ...)"
+                    "\n- MemmapTensor(size, ...)"
+                    "\n- MemmapTensor(*size, ...)"
+                )
+            shape = (
+                torch.Size([elem] + list(size))
+                if isinstance(elem, int)
+                else torch.Size(elem)
+            )
+            device = device if device is not None else torch.device("cpu")
+            dtype = dtype if dtype is not None else torch.get_default_dtype()
+            return self._init_shape(shape, device, dtype, transfer_ownership)
+
+    def _init_shape(
+        self,
+        shape: torch.Size,
+        device: DEVICE_TYPING,
+        dtype: torch.dtype,
+        transfer_ownership: bool,
+    ):
+        self._device = device
+        self._shape = shape
+        self.transfer_ownership = transfer_ownership
+        self.np_shape = tuple(self._shape)
+        self._dtype = dtype
+        self._ndim = len(shape)
+        self._numel = prod(shape)
+        self.mode = "r+"
+        self._has_ownership = True
+
+        self._tensor_dir = torch.zeros(1, device=device, dtype=dtype).__dir__()
+        self._save_item(shape)
+
+    def _init_tensor(
+        self, elem: Union[torch.Tensor, MemmapTensor], transfer_ownership: bool
     ):
         if not isinstance(elem, (torch.Tensor, MemmapTensor)):
             raise TypeError(
@@ -110,10 +169,6 @@ def __init__(
                 "Consider calling tensor.detach() first."
             )
 
-        self.idx = None
-        self._memmap_array = None
-        self.file = tempfile.NamedTemporaryFile()
-        self.filename = self.file.name
         self._device = elem.device
         self._shape = elem.shape
         self.transfer_ownership = transfer_ownership
@@ -153,11 +208,15 @@ def _set_memmap_array(self, value: np.memmap) -> None:
 
     def _save_item(
         self,
-        value: Union[torch.Tensor, MemmapTensor, np.ndarray],
+        value: Union[torch.Tensor, torch.Size, MemmapTensor, np.ndarray],
         idx: Optional[int] = None,
     ):
         if isinstance(value, (torch.Tensor,)):
             np_array = value.cpu().numpy()
+        elif isinstance(value, torch.Size):
+            # create the memmap array on disk
+            _ = self.memmap_array
+            return
         else:
             np_array = value
         memmap_array = self.memmap_array
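
For context, a minimal usage sketch of the constructor paths this change enables. The import path simply mirrors the file location in the diff and may differ from the recommended public API; shapes, device and dtype values are arbitrary examples.

    # Hypothetical usage sketch, assuming MemmapTensor is importable from the
    # module shown in the diff.
    import torch
    from torchrl.data.tensordict.memmap import MemmapTensor

    # 1) From an existing tensor (passing device/dtype here raises TypeError):
    mt_from_tensor = MemmapTensor(torch.randn(3, 4))

    # 2) From an explicit shape, with optional keyword-only device and dtype:
    mt_from_shape = MemmapTensor(3, 4, device=torch.device("cpu"), dtype=torch.float32)

    # 3) From a single size object instead of unpacked ints:
    mt_from_size = MemmapTensor(torch.Size([3, 4]))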

torchrl/trainers/trainers.py

Lines changed: 1 addition & 0 deletions
@@ -370,6 +370,7 @@ def train(self):
 
             if self.collected_frames > self.total_frames:
                 break
+        self.collector.shutdown()
         self.save_trainer(force_save=True)
 
     def __del__(self):
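
A short sketch of the teardown order this one-line addition establishes in Trainer.train(); only collector.shutdown() and save_trainer() come from the diff, the loop body is elided and the counter updates are illustrative.

    # Sketch of train() after this commit (inside the Trainer class):
    for batch in self.collector:
        # ... optimization and logging steps ...
        if self.collected_frames > self.total_frames:
            break
    self.collector.shutdown()           # new: release collector resources first
    self.save_trainer(force_save=True)  # then write the final checkpoint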
