7
7
8
8
import functools
9
9
import tempfile
10
+ from math import prod
10
11
from typing import Any , Callable , List , Optional , Tuple , Union
11
12
12
13
import numpy as np
@@ -97,7 +98,65 @@ class MemmapTensor(object):
97
98
def __init__ (
98
99
self ,
99
100
elem : Union [torch .Tensor , MemmapTensor ],
101
+ * size : int ,
102
+ device : DEVICE_TYPING = None ,
103
+ dtype : torch .dtype = None ,
100
104
transfer_ownership : bool = False ,
105
+ ):
106
+ self .idx = None
107
+ self ._memmap_array = None
108
+ self .file = tempfile .NamedTemporaryFile ()
109
+ self .filename = self .file .name
110
+
111
+ if isinstance (elem , (torch .Tensor , MemmapTensor , np .ndarray )):
112
+ if device is not None :
113
+ raise TypeError (
114
+ "device cannot be passed when creating a MemmapTensor from a tensor"
115
+ )
116
+ if dtype is not None :
117
+ raise TypeError (
118
+ "dtype cannot be passed when creating a MemmapTensor from a tensor"
119
+ )
120
+ return self ._init_tensor (elem , transfer_ownership )
121
+ else :
122
+ if not isinstance (elem , int ) and size :
123
+ raise TypeError (
124
+ "Valid init methods for MemmapTensor are: "
125
+ "\n - MemmapTensor(tensor, ...)"
126
+ "\n - MemmapTensor(size, ...)"
127
+ "\n - MemmapTensor(*size, ...)"
128
+ )
129
+ shape = (
130
+ torch .Size ([elem ] + list (size ))
131
+ if isinstance (elem , int )
132
+ else torch .Size (elem )
133
+ )
134
+ device = device if device is not None else torch .device ("cpu" )
135
+ dtype = dtype if dtype is not None else torch .get_default_dtype ()
136
+ return self ._init_shape (shape , device , dtype , transfer_ownership )
137
+
138
+ def _init_shape (
139
+ self ,
140
+ shape : torch .Size ,
141
+ device : DEVICE_TYPING ,
142
+ dtype : torch .dtype ,
143
+ transfer_ownership : bool ,
144
+ ):
145
+ self ._device = device
146
+ self ._shape = shape
147
+ self .transfer_ownership = transfer_ownership
148
+ self .np_shape = tuple (self ._shape )
149
+ self ._dtype = dtype
150
+ self ._ndim = len (shape )
151
+ self ._numel = prod (shape )
152
+ self .mode = "r+"
153
+ self ._has_ownership = True
154
+
155
+ self ._tensor_dir = torch .zeros (1 , device = device , dtype = dtype ).__dir__ ()
156
+ self ._save_item (shape )
157
+
158
+ def _init_tensor (
159
+ self , elem : Union [torch .Tensor , MemmapTensor ], transfer_ownership : bool
101
160
):
102
161
if not isinstance (elem , (torch .Tensor , MemmapTensor )):
103
162
raise TypeError (
@@ -110,10 +169,6 @@ def __init__(
110
169
"Consider calling tensor.detach() first."
111
170
)
112
171
113
- self .idx = None
114
- self ._memmap_array = None
115
- self .file = tempfile .NamedTemporaryFile ()
116
- self .filename = self .file .name
117
172
self ._device = elem .device
118
173
self ._shape = elem .shape
119
174
self .transfer_ownership = transfer_ownership
@@ -153,11 +208,15 @@ def _set_memmap_array(self, value: np.memmap) -> None:
153
208
154
209
def _save_item (
155
210
self ,
156
- value : Union [torch .Tensor , MemmapTensor , np .ndarray ],
211
+ value : Union [torch .Tensor , torch . Size , MemmapTensor , np .ndarray ],
157
212
idx : Optional [int ] = None ,
158
213
):
159
214
if isinstance (value , (torch .Tensor ,)):
160
215
np_array = value .cpu ().numpy ()
216
+ elif isinstance (value , torch .Size ):
217
+ # create the memmap array on disk
218
+ _ = self .memmap_array
219
+ return
161
220
else :
162
221
np_array = value
163
222
memmap_array = self .memmap_array
0 commit comments