|
21 | 21 | torch_to_numpy_dtype_dict,
|
22 | 22 | )
|
23 | 23 |
|
24 |
| -# try: |
25 |
| -# from torch.utils._python_dispatch import enable_torch_dispatch_mode |
26 |
| -# from torch._subclasses.fake_tensor import ( |
27 |
| -# FakeTensor, |
28 |
| -# FakeTensorMode, |
29 |
| -# FakeTensorConverter, |
30 |
| -# DynamicOutputShapeException, |
31 |
| -# ) |
32 |
| -# _has_fake = True |
33 |
| -# except: |
34 |
| -_has_fake = False |
35 |
| - |
36 | 24 | MEMMAP_HANDLED_FN = {}
|
37 | 25 |
|
38 | 26 | __all__ = ["MemmapTensor", "set_transfer_ownership"]
|
@@ -74,15 +62,23 @@ class MemmapTensor(object):
|
74 | 62 | Supports (almost) all tensor operations.
|
75 | 63 |
|
76 | 64 | Args:
|
77 |
| - elem (torch.Tensor or MemmapTensor): TODO // Tensor to be stored on physical |
78 |
| - storage. If MemmapTensor, a new MemmapTensor is created and the |
79 |
| - same data is stored in it. |
80 |
| - transfer_ownership: bool: affects the ownership after serialization: |
| 65 | + *tensor_or_size (torch.Tensor, MemmapTensor, torch.Size or sequence of integers): |
| 66 | + If a size is provided (with a sequence of integers, a torch.Size object |
| 67 | + or a list/tuple of integers) it indicates the size of the MemmapTensor created. |
| 68 | + If a tensor is provided, its content will be stored on physical storage. |
| 69 | + If MemmapTensor, a new MemmapTensor is created and the same data is stored in it. |
| 70 | + device (torch.device or equivalent, optional): device where the loaded |
| 71 | + tensor will be sent. This should not be used with MemmapTensors |
| 72 | + created from torch.Tensor objects. Default is "cpu". |
| 73 | + dtype (torch.dtype, optional): dtype of the loaded tensor. |
| 74 | + This should not be used with MemmapTensors created from torch.Tensor |
| 75 | + objects. Default is `torch.get_default_dtype()`. |
| 76 | + transfer_ownership (bool, optional): affects the ownership after serialization: |
81 | 77 | if True, the current process loses ownership immediately after
|
82 | 78 | serialization. If False, the current process keeps the ownership
|
83 | 79 | of the temporary file.
|
84 | 80 | Default: False.
|
85 |
| - prefix: TODO |
| 81 | + prefix (str or path, optional): prefix of the file location. |
86 | 82 |
|
87 | 83 | Examples:
|
88 | 84 | >>> x = torch.ones(3,4)
|
@@ -153,9 +149,6 @@ def __init__(
|
153 | 149 | device = device if device is not None else torch.device("cpu")
|
154 | 150 | dtype = dtype if dtype is not None else torch.get_default_dtype()
|
155 | 151 | self._init_shape(shape, device, dtype, transfer_ownership)
|
156 |
| - if _has_fake: |
157 |
| - with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): |
158 |
| - self._fake = torch.zeros(self.shape, device=self.device) |
159 | 152 |
|
160 | 153 | def _init_shape(
|
161 | 154 | self,
|
@@ -272,10 +265,7 @@ def _load_item(
|
272 | 265 | and len(idx) == 1
|
273 | 266 | and not (isinstance(idx, torch.Tensor) and idx.dtype is torch.bool)
|
274 | 267 | ): # and isinstance(idx, torch.Tensor) and len(idx) == 1:
|
275 |
| - if _has_fake: |
276 |
| - size = self._fake[idx].shape |
277 |
| - else: |
278 |
| - size = _getitem_batch_size(self.shape, idx) |
| 268 | + size = _getitem_batch_size(self.shape, idx) |
279 | 269 | out = out.view(size)
|
280 | 270 | return out
|
281 | 271 |
|
|
0 commit comments