forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSymInt.cpp
284 lines (258 loc) · 9.88 KB
/
SymInt.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/safe_numerics.h>
#include <functional>
namespace c10 {
// Precondition: data_ has a large negative number that should be
// treated as a constant. It is NOT a valid pointer. In other words,
// SymInt has temporarily violated invariants
// Postcondition: invariants on SymInt are fixed
void SymInt::promote_to_negative() {
auto s =
SymInt(SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(data_)));
// Similar to move operator=, but do NOT release data_
data_ = s.data_;
s.data_ = 0;
}
SymNode SymInt::toSymNode() const {
  // Only the heap-allocated (symbolic) representation has a node to return.
  TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
      is_heap_allocated(), "SymInt::toSymNode is_heap_allocated");
  // toSymNodeImplUnowned() yields a non-owning pointer; reclaim_copy wraps it
  // into an owning SymNode handle for the caller.
  auto* const unowned = toSymNodeImplUnowned();
  return SymNode::reclaim_copy(unowned);
}
SymInt::SymInt(SymNode sin_sp) {
  // Only integer-valued nodes may back a SymInt.
  TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
      sin_sp->is_int(), "SymInt::SymInt sin_sp->is_int()");
  // release() detaches the raw node pointer from sin_sp; from here on the
  // node is referenced through data_.
  void* const node_addr = static_cast<void*>(sin_sp.release());
  const uint64_t addr_bits =
      static_cast<uint64_t>(reinterpret_cast<uintptr_t>(node_addr));
  // Clear the bits selected by MASK, then set the IS_SYM tag.
  data_ = static_cast<int64_t>((addr_bits & ~MASK) | IS_SYM);
}
bool SymInt::has_hint() const {
  // A plain (non-heap) integer is its own hint; a symbolic value asks its node.
  return !is_heap_allocated() || toSymNodeImplUnowned()->has_hint();
}
// DEFINE_BINARY(API, OP, METHOD, RET) defines the member function
//   RET SymInt::API(const SymInt& sci) const
// which dispatches on whether each operand is a concrete integer
// (i.e. maybe_as_int() yields a value):
//   - both concrete: apply OP to the two int64_t values and wrap the result
//     in RET;
//   - only sci symbolic: lift *this into sci's node via wrap_int, then call
//     METHOD on the lifted node with sci's node;
//   - only *this symbolic: lift sci into this node via wrap_int, then call
//     METHOD on this node with the lifted value;
//   - both symbolic: call METHOD on this node with sci's node.
// The concrete path performs no checks of its own: division/modulus by zero
// and signed overflow are whatever OP does on int64_t (undefined behavior).
// NOTE(review): the concrete path uses C++ operators (std::divides and
// std::modulus truncate toward zero) while the symbolic path calls
// floordiv/mod; for negative operands these may disagree — confirm intended.
#define DEFINE_BINARY(API, OP, METHOD, RET) \
RET SymInt::API(const SymInt& sci) const { \
if (auto ma = maybe_as_int()) { \
if (auto mb = sci.maybe_as_int()) { \
return RET(OP(*ma, *mb)); \
} else { \
auto b = sci.toSymNode(); \
return RET(b->wrap_int(*ma)->METHOD(b)); \
} \
} else { \
if (auto mb = sci.maybe_as_int()) { \
auto a = toSymNodeImplUnowned(); \
return RET(a->METHOD(a->wrap_int(*mb))); \
} else { \
return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNode())); \
} \
} \
}
// clang-format off
// Arithmetic, producing SymInt.
DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt)
DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt)
// Comparisons, producing SymBool rather than a plain bool.
DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool)
DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool)
DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool)
DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool)
DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
// min/max: std::min / std::max are named directly as OP, so the concrete
// path calls them on the two int64_t values.
DEFINE_BINARY(min, std::min, sym_min, SymInt)
DEFINE_BINARY(max, std::max, sym_max, SymInt)
// clang-format on
SymInt::operator SymFloat() const {
  // Symbolic values ask their node for a float-typed node; concrete values
  // are simply widened to double.
  const auto concrete = maybe_as_int();
  if (!concrete) {
    return SymFloat(toSymNodeImplUnowned()->sym_float());
  }
  return SymFloat(static_cast<double>(*concrete));
}
bool SymInt::is_same(const SymInt& other) const {
  // Values with different representations (heap vs. inline) are never the same.
  const bool on_heap = is_heap_allocated();
  if (on_heap != other.is_heap_allocated()) {
    return false;
  }
  // Symbolic: same only when both point at the very same node object.
  if (on_heap) {
    return toSymNodeImplUnowned() == other.toSymNodeImplUnowned();
  }
  // Concrete: same when the values compare equal.
  return !this->operator!=(other);
}
SymNode SymInt::wrap_node(const SymNode& base) const {
  // Symbolic: hand back (an owning handle to) the existing node.
  // Concrete: let `base` wrap the integer via wrap_int.
  const auto concrete = maybe_as_int();
  if (!concrete) {
    return toSymNode();
  }
  return base->wrap_int(*concrete);
}
SymInt SymInt::clone() const {
  const auto concrete = maybe_as_int();
  if (!concrete) {
    // Symbolic: wrap a clone of the underlying node.
    return SymInt(toSymNodeImplUnowned()->clone());
  }
  // Concrete: a new SymInt holding the same value.
  return SymInt(*concrete);
}
int64_t SymInt::guard_int(const char* file, int64_t line) const {
  const auto concrete = maybe_as_int();
  if (!concrete) {
    // Symbolic: the node produces the integer; the caller's file/line are
    // forwarded to it.
    return toSymNodeImplUnowned()->guard_int(file, line);
  }
  return *concrete;
}
bool SymInt::expect_size(const char* file, int64_t line) const {
  const auto concrete = maybe_as_int();
  if (!concrete) {
    // Symbolic: defer to the node, forwarding the caller's file/line.
    return toSymNodeImplUnowned()->expect_size(file, line);
  }
  // Concrete: a size is expected to be non-negative.
  return *concrete >= 0;
}
// Unary minus. A concrete value is negated in int64_t arithmetic; a symbolic
// value delegates to its node's neg(). The minimum int64_t value, whose
// negation is not representable, is returned unchanged.
SymInt operator-(const SymInt& s) {
if (auto ma = s.maybe_as_int()) {
const auto val = *ma;
// Note: the result of `-std::numeric_limits<decltype(val)>::min()` is
// undefined. On many platforms it evaluates to the value itself while
// setting the carry/overflow flags, which in optimized code can affect the
// result of the `check_range` condition.
// Work around this by special-casing the minimum value instead of negating
// it directly: via __builtin_sub_overflow when available, otherwise via a
// ternary that avoids altering the flags.
#if C10_HAS_BUILTIN_OVERFLOW()
// Compute 0 - val into `out`; the builtin reports overflow only when val is
// the minimum value, which is then returned as-is.
std::decay_t<decltype(val)> out = 0;
if (C10_UNLIKELY(__builtin_sub_overflow(out, val, &out))) {
return SymInt(val);
}
return SymInt(out);
#else
constexpr auto val_min = std::numeric_limits<decltype(val)>::min();
return SymInt(val != val_min ? -val : val_min);
#endif
} else {
return SymInt(s.toSymNodeImplUnowned()->neg());
}
}
void SymInt::operator*=(const SymInt& sci) {
  // Compute the product with the binary member operator and assign it back.
  *this = this->operator*(sci);
}
void SymInt::operator/=(const SymInt& sci) {
  // Compute the quotient with the binary member operator and assign it back.
  *this = this->operator/(sci);
}
void SymInt::operator+=(const SymInt& sci) {
  // Compute the sum with the binary member operator and assign it back.
  *this = this->operator+(sci);
}
std::ostream& operator<<(std::ostream& os, const SymInt& s) {
  // Plain integers are printed as numbers; symbolic values print their
  // node's string form.
  if (!s.is_heap_allocated()) {
    return os << s.as_int_unchecked();
  }
  return os << s.toSymNodeImplUnowned()->str();
}
// Convert<RetTy> turns a SymInt operand into the operator macros' result type.
// It is only used inside this translation unit, so it lives in an anonymous
// namespace (internal linkage) to avoid ODR clashes with any other
// c10::Convert.
namespace {
template <typename T>
struct Convert {};

// Identity conversion: returns the argument by reference, so no copy (and
// thus no refcount bump) happens when RetTy is SymInt itself.
template <>
struct Convert<SymInt> {
  const SymInt& operator()(const SymInt& a) const {
    return a;
  }
};

// SymInt -> SymFloat goes through SymInt::operator SymFloat().
template <>
struct Convert<SymFloat> {
  SymFloat operator()(const SymInt& a) const {
    return a;
  }
};
} // namespace
// Defines the mixed SymInt/scalar modulus operator in both operand orders,
// returning RetTy. The scalar is converted with RetTy(...), and the SymInt
// with Convert<RetTy> (which avoids a refcount bump when RetTy is SymInt).
// RetTy's own operator% then does the work.
// No ';' follows the function bodies: a stray one is an empty declaration
// that -Wextra-semi / -Wpedantic flag.
#define DEFINE_SYMINT_OP_INTONLY(scalar_t, RetTy) \
  RetTy operator%(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) % RetTy(b); \
  } \
  RetTy operator%(scalar_t a, const SymInt& b) { \
    return RetTy(a) % Convert<RetTy>()(b); \
  }
// Defines the mixed SymInt/scalar operators for scalar type scalar_t, in both
// operand orders:
//   - arithmetic (+ - * /) returning RetTy;
//   - comparisons (== != < <= > >=) returning bool.
// The scalar is converted with RetTy(...) and the SymInt with Convert<RetTy>,
// after which RetTy's own operators do the work. In this file RetTy is SymInt
// for integer scalars and SymFloat for floating-point scalars.
// No ';' follows the function bodies: a stray one is an empty declaration
// that -Wextra-semi / -Wpedantic flag.
#define DEFINE_SYMINT_OP(scalar_t, RetTy) \
  RetTy operator+(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) + RetTy(b); \
  } \
  RetTy operator-(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) - RetTy(b); \
  } \
  RetTy operator*(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) * RetTy(b); \
  } \
  RetTy operator/(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) / RetTy(b); \
  } \
  RetTy operator+(scalar_t a, const SymInt& b) { \
    return RetTy(a) + Convert<RetTy>()(b); \
  } \
  RetTy operator-(scalar_t a, const SymInt& b) { \
    return RetTy(a) - Convert<RetTy>()(b); \
  } \
  RetTy operator*(scalar_t a, const SymInt& b) { \
    return RetTy(a) * Convert<RetTy>()(b); \
  } \
  RetTy operator/(scalar_t a, const SymInt& b) { \
    return RetTy(a) / Convert<RetTy>()(b); \
  } \
  bool operator==(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) == RetTy(b); \
  } \
  bool operator!=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) != RetTy(b); \
  } \
  bool operator<(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) < RetTy(b); \
  } \
  bool operator<=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) <= RetTy(b); \
  } \
  bool operator>(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) > RetTy(b); \
  } \
  bool operator>=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) >= RetTy(b); \
  } \
  bool operator==(scalar_t a, const SymInt& b) { \
    return RetTy(a) == Convert<RetTy>()(b); \
  } \
  bool operator!=(scalar_t a, const SymInt& b) { \
    return RetTy(a) != Convert<RetTy>()(b); \
  } \
  bool operator<(scalar_t a, const SymInt& b) { \
    return RetTy(a) < Convert<RetTy>()(b); \
  } \
  bool operator<=(scalar_t a, const SymInt& b) { \
    return RetTy(a) <= Convert<RetTy>()(b); \
  } \
  bool operator>(scalar_t a, const SymInt& b) { \
    return RetTy(a) > Convert<RetTy>()(b); \
  } \
  bool operator>=(scalar_t a, const SymInt& b) { \
    return RetTy(a) >= Convert<RetTy>()(b); \
  }
// Integer scalars: modulus plus arithmetic/comparisons, producing SymInt.
DEFINE_SYMINT_OP_INTONLY(int64_t, SymInt)
DEFINE_SYMINT_OP(int64_t, SymInt)
DEFINE_SYMINT_OP_INTONLY(int32_t, SymInt)
DEFINE_SYMINT_OP(int32_t, SymInt) // make sure constants work
DEFINE_SYMINT_OP_INTONLY(uint64_t, SymInt)
DEFINE_SYMINT_OP(uint64_t, SymInt)
DEFINE_SYMINT_OP_INTONLY(uint32_t, SymInt)
DEFINE_SYMINT_OP(uint32_t, SymInt)
// Floating-point scalars: arithmetic/comparisons only, producing SymFloat.
DEFINE_SYMINT_OP(double, SymFloat)
DEFINE_SYMINT_OP(float, SymFloat) // just for completeness
#if defined(__APPLE__)
// NOTE(review): presumably needed because size_t is a distinct type from
// uint64_t/uint32_t on this platform — confirm.
DEFINE_SYMINT_OP_INTONLY(size_t, SymInt) // needed for osx
DEFINE_SYMINT_OP(size_t, SymInt) // needed for osx
#endif
} // namespace c10