#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/TensorUtils.h>
#include <c10/core/CPUAllocator.h>

namespace at { namespace native {

// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make use of
// resize_output directly. For more details on how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362

// Resizes outputs.
// Functions accepting output tensors, like those with an "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing.
// NOTE: In the future the warning will become an error.
// Returns a bool indicating whether the resize actually happened.
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);

// Utility for resize_output.
// Returns a bool indicating whether the resize should happen, and
// raises a warning if the output being resized has one or more elements.
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
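
// A minimal usage sketch (illustrative only; the kernel name below is
// hypothetical and not part of this header): an out-variant kernel would
// typically resize its output before writing into it.
//
//   Tensor& add_out_example(const Tensor& a, const Tensor& b, Tensor& out) {
//     at::native::resize_output(out, a.sizes());  // warns if a non-empty `out` is resized
//     // ... compute the result directly into `out` ...
//     return out;
//   }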
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);

static inline void maybe_resize_storage_cpu(TensorImpl* self, uint64_t new_size) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in cuda/Resize.h)
if (new_size == 0) {
return;
}
const auto new_size_bytes_i =
(new_size + self->storage_offset()) * self->dtype().itemsize();
TORCH_CHECK(!overflows<size_t>(new_size_bytes_i), "Requested storage size (",
new_size_bytes_i, ") cannot be represented as a size_t");
const auto new_size_bytes = static_cast<size_t>(new_size_bytes_i);
const Storage& storage = self->unsafe_storage();
if (!storage) {
auto new_storage = c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
new_size_bytes,
c10::GetCPUAllocator(),
true);
self->set_storage_keep_dtype(std::move(new_storage));
} else if (new_size_bytes > storage.nbytes()) {
resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
}
}
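
// Illustrative note (an editorial example, not from the original source): for a
// float tensor with storage_offset() == 2 and new_size == 6 elements, the
// requested size is (6 + 2) * 4 = 32 bytes, so the storage is grown only if it
// currently holds fewer than 32 bytes; with new_size == 0 the function returns
// early and leaves the storage untouched.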
inline TensorImpl* resize_impl_cpu_(
TensorImpl* self,
IntArrayRef size,
c10::optional<IntArrayRef> stride,
bool resize_storage = true) {
if (self->sizes() == size && (!stride || self->strides() == stride)) {
return self;
}
int64_t storage_size = 1;
if (stride) {
self->set_sizes_and_strides(size, *stride);
// NB: storage size can be different from numel.
storage_size = storage_size_for(size, *stride);
} else {
self->set_sizes_contiguous(size);
storage_size = self->numel();
}
if (resize_storage) {
maybe_resize_storage_cpu(self, storage_size);
}
return self;
}
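
// A minimal sketch of a typical call site (an assumption for illustration; see
// ATen's CPU `resize_` implementation for the authoritative usage):
//
//   // Resize `self` in place to `sizes`, laying it out contiguously.
//   auto* self_impl = self.unsafeGetTensorImpl();
//   at::native::resize_impl_cpu_(self_impl, sizes, /*stride=*/c10::nullopt);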
static inline void checkInBoundsForStorage(
IntArrayRef size,
IntArrayRef stride,
int64_t storage_offset,
const caffe2::TypeMeta data_type,
const Storage& new_storage) {
int64_t storage_size_bytes =
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
int64_t storage_offset_bytes = storage_offset * data_type.itemsize();
if (storage_size_bytes == 0) {
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
return;
}
int64_t new_storage_size_bytes = new_storage.nbytes();
TORCH_CHECK(
storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
"setStorage: sizes ",
size,
", strides ",
stride,
","
" storage offset ",
storage_offset,
", and itemsize ",
data_type.itemsize(),
" requiring a storage size of ",
storage_size_bytes + storage_offset_bytes,
" are out of bounds for storage of size ",
new_storage_size_bytes);
}
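
// Worked example (editorial illustration, assuming computeStorageNbytes returns
// itemsize * (1 + sum((size[i] - 1) * stride[i])) for a non-empty tensor): for
// size {2, 3}, stride {3, 1}, storage_offset 2, and float32 (itemsize 4),
// storage_size_bytes = 4 * (1 + 3 + 2) = 24 and storage_offset_bytes = 8, so the
// check requires new_storage.nbytes() >= 32.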
static inline void checkSetStorage(Tensor& result, Storage storage, int64_t storage_offset,
IntArrayRef size, IntArrayRef stride) {
// FIXME: stride should be optional
if (stride.data()) {
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
") and stride length (", stride.size(), ")");
}
#ifdef DEBUG
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
// function are to set the tensor's size to be equal to the size of the storage.
if (!result.storage().is_alias_of(storage)) {
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
TORCH_INTERNAL_ASSERT(storage);
TORCH_INTERNAL_ASSERT(result.storage());
// We used to allow this, but this breaks device caching.
// Let's put an actual error message for this one.
TORCH_CHECK(result.storage().device() == storage.device(),
"Attempted to set the storage of a tensor on device \"", result.storage().device(),
"\" to a storage on different device \"", storage.device(),
"\". This is no longer allowed; the devices must match.");
result.unsafeGetTensorImpl()->set_storage_keep_dtype(storage);
}
// storageOffset
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}

/**
* Set self's sizes, strides, and storage_offset.
* (size, stride, storage_offset) must be in bounds for self's storage.
*/
inline void setStrided(
const Tensor& self,
IntArrayRef size,
IntArrayRef stride,
int64_t storage_offset) {
TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
auto* self_ = self.unsafeGetTensorImpl();
checkInBoundsForStorage(
size, stride, storage_offset, self_->dtype(), self_->storage());
/* storage offset */
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
self_->set_storage_offset(storage_offset);
/* size and stride */
if (self_->sizes() == size && self_->strides() == stride) {
return;
}
for (auto val : stride) {
TORCH_CHECK(val >= 0,
"as_strided: Negative strides are not supported at the moment, "
"got strides: ", stride);
}
self_->set_sizes_and_strides(size, stride);
}
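
// A minimal usage sketch (assumed call site, for illustration only): an
// as_strided-style view can be produced by restriding an alias that shares
// storage with the source tensor.
//
//   Tensor result = self.alias();  // shares self's storage
//   at::native::setStrided(result, sizes, strides, storage_offset);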
} // namespace native
} // namespace at