import torch
import warnings
from . import _tensor_str
from ._utils import _type, _cuda, _range, _rebuild_tensor
import sys


class _TensorBase(object):
    #: bool: True if this is a CUDA tensor
    is_cuda = False
    is_sparse = False

    # NB: This implementation is CPU only; see THPTensor_(new) for the
    # CUDA case, which handles constructing the tensor on the same GPU
    # as this tensor.
    def new(self, *args, **kwargs):
        """Constructs a new tensor of the same data type."""
        return self.__class__(*args, **kwargs)
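
    # Illustrative usage: `new` builds a tensor of the same class as its
    # receiver, which lets code stay generic over tensor types.
    #   >>> x = torch.LongTensor(5)
    #   >>> y = x.new(2, 3)   # an uninitialized 2x3 torch.LongTensor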

    def type_as(self, tensor):
        """Returns this tensor cast to the type of the given tensor.

        This is a no-op if the tensor is already of the correct type. This is
        equivalent to::

            self.type(tensor.type())

        Params:
            tensor (Tensor): the tensor which has the desired type
        """
        return self.type(tensor.type())
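
    # Illustrative usage:
    #   >>> x = torch.FloatTensor(3)
    #   >>> d = torch.DoubleTensor(3)
    #   >>> x.type_as(d)   # a DoubleTensor copy of x
    #   >>> d.type_as(d)   # no-op: d already has the right type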

    def cpu(self):
        """Returns a CPU copy of this tensor if it's not already on the CPU"""
        return self.type(getattr(torch, self.__class__.__name__))

    def double(self):
        """Casts this tensor to double type"""
        return self.type(type(self).__module__ + '.DoubleTensor')

    def float(self):
        """Casts this tensor to float type"""
        return self.type(type(self).__module__ + '.FloatTensor')

    def half(self):
        """Casts this tensor to half-precision float type"""
        return self.type(type(self).__module__ + '.HalfTensor')

    def long(self):
        """Casts this tensor to long type"""
        return self.type(type(self).__module__ + '.LongTensor')

    def int(self):
        """Casts this tensor to int type"""
        return self.type(type(self).__module__ + '.IntTensor')

    def short(self):
        """Casts this tensor to short type"""
        return self.type(type(self).__module__ + '.ShortTensor')

    def char(self):
        """Casts this tensor to char type"""
        return self.type(type(self).__module__ + '.CharTensor')

    def byte(self):
        """Casts this tensor to byte type"""
        return self.type(type(self).__module__ + '.ByteTensor')
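
    # Illustrative usage: the casts above all delegate to `type`, e.g.
    #   >>> t = torch.FloatTensor([1.5])
    #   >>> t.long()           # torch.LongTensor([1]); truncates the fraction
    #   >>> t.byte().float()   # round-trips through torch.ByteTensor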

    def is_pinned(self):
        """Returns true if this tensor resides in pinned memory"""
        storage = self.storage()
        return storage.is_pinned() if storage else False

    def pin_memory(self):
        """Copies the tensor to pinned memory, if it's not already pinned."""
        if self.is_cuda:
            raise TypeError("cannot pin '{0}'; only CPU memory can be pinned"
                            .format(self.type()))
        storage = self.storage()
        if storage is None:
            storage = (self.storage_type())()
        return type(self)().set_(storage.pin_memory()).view_as(self)
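
    # Illustrative usage: pinned (page-locked) memory allows faster,
    # asynchronous copies to the GPU.
    #   >>> x = torch.FloatTensor(10)
    #   >>> p = x.pin_memory()   # a CPU copy living in pinned memory
    #   >>> p.is_pinned()        # True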

    def share_memory_(self):
        """Moves the underlying storage to shared memory.

        This is a no-op if the underlying storage is already in shared memory
        and for CUDA tensors. Tensors in shared memory cannot be resized.
        """
        self.storage().share_memory_()
        return self
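
    # Illustrative usage: shared storage is what lets torch.multiprocessing
    # pass tensors between processes without copying.
    #   >>> t = torch.FloatTensor(4)
    #   >>> t.share_memory_()   # moves the storage into shared memory
    #   >>> t.is_shared()       # True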

    def is_shared(self):
        """Checks if tensor is in shared memory.

        This is always ``True`` for CUDA tensors.
        """
        return self.storage().is_shared()

    @property
    def shape(self):
        """Alias for .size()

        Returns a torch.Size object, containing the dimensions of the tensor
        """
        return self.size()
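
    # Illustrative usage:
    #   >>> x = torch.randn(2, 3)
    #   >>> x.shape   # torch.Size([2, 3]), equivalent to x.size()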

    def __deepcopy__(self, _memo):
        memo = _memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.storage().__deepcopy__(_memo)
        new_tensor = self.new()
        new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
        memo[self._cdata] = new_tensor
        return new_tensor

    def __reduce__(self):
        # NOTE: _rebuild_tensor does not call __setstate__
        args = self.__getstate__()
        return (_rebuild_tensor, args)

    def __getstate__(self):
        return (self.storage(),
                self.storage_offset(),
                tuple(self.size()),
                self.stride())

    def __setstate__(self, state):
        self.set_(*state)
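
    # Illustrative usage: __reduce__ and __getstate__ make tensors picklable;
    # the (storage, offset, size, stride) tuple fully describes the view.
    #   >>> import pickle
    #   >>> t = torch.randn(3)
    #   >>> t2 = pickle.loads(pickle.dumps(t))   # rebuilt via _rebuild_tensor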

    def __repr__(self):
        return str(self)

    def __str__(self):
        # All strings are unicode in Python 3, while we have to encode unicode
        # strings in Python 2. If we can't, let Python decide the best
        # characters to replace unicode characters with.
        if sys.version_info > (3,):
            return _tensor_str._str(self)
        else:
            if hasattr(sys.stdout, 'encoding'):
                return _tensor_str._str(self).encode(
                    sys.stdout.encoding or 'UTF-8', 'replace')
            else:
                return _tensor_str._str(self).encode('UTF-8', 'replace')

    def __bool__(self):
        if self.numel() == 0:
            return False
        raise RuntimeError("bool value of non-empty " + torch.typename(self) +
                           " objects is ambiguous")
    __nonzero__ = __bool__

    def __iter__(self):
        if self.nelement() > 0:
            return iter(map(lambda i: self.select(0, i), _range(self.size(0))))
        else:
            return iter([])

    def split(self, split_size, dim=0):
        """Splits this tensor into a tuple of tensors.

        See :func:`torch.split`.
        """
        return torch.split(self, split_size, dim)
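
    # Illustrative usage:
    #   >>> x = torch.randn(5, 2)
    #   >>> a, b, c = x.split(2)   # sizes (2, 2), (2, 2) and (1, 2) along dim 0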

    def chunk(self, n_chunks, dim=0):
        """Splits this tensor into a given number of chunks along a dimension.

        See :func:`torch.chunk`.
        """
        return torch.chunk(self, n_chunks, dim)
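
    # Illustrative usage:
    #   >>> x = torch.randn(5, 2)
    #   >>> a, b, c = x.chunk(3)   # chunks of size 2, 2 and 1 along dim 0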

    def matmul(self, other):
        """Matrix product of two tensors.

        See :func:`torch.matmul`.
        """
        return torch.matmul(self, other)

    def tolist(self):
        """Returns a nested list representation of this tensor."""
        dim = self.dim()
        if dim == 1:
            return [v for v in self]
        elif dim > 0:
            return [subt.tolist() for subt in self]
        return []
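
    # Illustrative usage:
    #   >>> torch.Tensor([[1, 2], [3, 4]]).tolist()
    #   [[1.0, 2.0], [3.0, 4.0]]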

    def view_as(self, tensor):
        """Returns this tensor viewed as the same size as the specified tensor.

        This is equivalent to::

            self.view(tensor.size())
        """
        return self.view(tensor.size())
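
    # Illustrative usage (both tensors must hold the same number of elements):
    #   >>> x = torch.randn(4, 4)
    #   >>> y = torch.randn(2, 8)
    #   >>> x.view_as(y).size()   # torch.Size([2, 8])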

    def permute(self, *dims):
        """Permute the dimensions of this tensor.

        Args:
            *dims (int...): The desired ordering of dimensions

        Example:
            >>> x = torch.randn(2, 3, 5)
            >>> x.size()
            torch.Size([2, 3, 5])
            >>> x.permute(2, 0, 1).size()
            torch.Size([5, 2, 3])
        """
        perm = list(dims)
        tensor = self
        n_dims = tensor.dim()
        assert len(perm) == n_dims, 'Invalid permutation'
        for i, p in enumerate(perm):
            if p != i and p != -1:
                j = i
                while True:
                    assert 0 <= perm[j] and perm[j] < n_dims, 'Invalid permutation'
                    tensor = tensor.transpose(j, perm[j])
                    perm[j], j = -1, perm[j]
                    if perm[j] == i:
                        break
                perm[j] = -1
        return tensor

    def expand_as(self, tensor):
        """Expands this tensor to the size of the specified tensor.

        This is equivalent to::

            self.expand(tensor.size())
        """
        return self.expand(tensor.size())
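
    # Illustrative usage: only singleton dimensions can be expanded, and no
    # data is copied (the expanded view uses stride 0 along those dims).
    #   >>> col = torch.randn(3, 1)
    #   >>> mat = torch.randn(3, 4)
    #   >>> col.expand_as(mat).size()   # torch.Size([3, 4])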

    def repeat(self, *sizes):
        """Repeats this tensor along the specified dimensions.

        Unlike :meth:`expand`, this function copies the tensor's data.

        Args:
            *sizes (torch.Size or int...): The number of times to repeat this
                tensor along each dimension

        Example:
            >>> x = torch.Tensor([1, 2, 3])
            >>> x.repeat(4, 2)
             1  2  3  1  2  3
             1  2  3  1  2  3
             1  2  3  1  2  3
             1  2  3  1  2  3
            [torch.FloatTensor of size 4x6]
            >>> x.repeat(4, 2, 1).size()
            torch.Size([4, 2, 3])
        """
        # If args == (torch.Size,), then we need to unpack the tuple
        if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
            sizes = sizes[0]
        repeats = list(sizes)
        result = self.new()
        src = self.contiguous()
        if len(repeats) < src.dim():
            raise ValueError('Number of dimensions of repeat dims cannot be '
                             'smaller than number of dimensions of tensor')
        # Left-pad the source size with singleton dimensions so that it has
        # as many dimensions as `repeats`.
        xtensor = src.new().set_(src)
        xsize = list(xtensor.size())
        for i in _range(len(repeats) - src.dim()):
            xsize = [1] + xsize
        size = torch.Size([a * b for a, b in zip(xsize, repeats)])
        xtensor.resize_(torch.Size(xsize))
        result.resize_(size)
        # `urtensor` views `result`'s storage; each `unfold` splits a repeated
        # dimension into (repeat count, original size) blocks, so copying the
        # expanded source below tiles it across the whole result.
        urtensor = result.new(result)
        for i in _range(xtensor.dim()):
            urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))
        for i in _range(urtensor.dim() - xtensor.dim()):
            xsize = [1] + xsize
        xtensor.resize_(torch.Size(xsize))
        xxtensor = xtensor.expand_as(urtensor)
        urtensor.copy_(xxtensor)
        return result

    def masked_copy_(self, *args, **kwargs):
        warnings.warn("masked_copy_ is deprecated and renamed to masked_scatter_, and will be removed in v0.3")
        return self.masked_scatter_(*args, **kwargs)

    # TODO: add tests for operators
    def __add__(self, other):
        return self.add(other)
    __radd__ = __add__

    def __iadd__(self, other):
        return self.add_(other)

    def __sub__(self, other):
        return self.sub(other)

    def __rsub__(self, other):
        # other - self: fill a same-sized tensor with `other`, then add
        # (-1) * self
        return self.new().resize_as_(self).fill_(other).add_(-1, self)

    def __isub__(self, other):
        return self.sub_(other)

    def __mul__(self, other):
        return self.mul(other)
    __rmul__ = __mul__

    def __imul__(self, other):
        return self.mul_(other)

    def __matmul__(self, other):
        if not torch.is_tensor(other):
            return NotImplemented
        return self.matmul(other)

    def __pow__(self, other):
        return self.pow(other)

    def __ipow__(self, other):
        return self.pow_(other)

    def __div__(self, other):
        return self.div(other)
    __truediv__ = __div__

    def __rdiv__(self, other):
        # other / self: fill a same-sized tensor with `other`, then divide
        # it elementwise by self
        return self.new().resize_as_(self).fill_(other).div_(self)
    __rtruediv__ = __rdiv__

    def __idiv__(self, other):
        return self.div_(other)
    __itruediv__ = __idiv__

    def __mod__(self, other):
        return self.remainder(other)

    def __neg__(self):
        return self.neg()

    def __eq__(self, other):
        return self.eq(other)

    def __ne__(self, other):
        return self.ne(other)

    def __lt__(self, other):
        return self.lt(other)

    def __le__(self, other):
        return self.le(other)

    def __gt__(self, other):
        return self.gt(other)

    def __ge__(self, other):
        return self.ge(other)

    # TODO: add native and, or, and xor in the libs
    def __invert__(self):
        if type(self).__name__ != 'ByteTensor':
            raise RuntimeError('logical operations are supported on ByteTensors only')
        return (1 - self)

    def __hash__(self):
        return id(self)

    # provide user guidance when they inadvertently call autograd properties
    # on a Tensor
    @property
    def data(self):
        raise RuntimeError('cannot call .data on a torch.Tensor: did you intend to use autograd.Variable?')
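

# Bind the conversion helpers implemented in torch._utils as methods on
# _TensorBase; the concrete tensor classes inherit them from here.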
_TensorBase.type = _type
_TensorBase.cuda = _cuda