
Commit c7edcd6

Revert "Don't introduce new overload for SymInt (pytorch#83628)"
This reverts commit 9790d90. Reverted pytorch#83628 on behalf of https://github.com/malfet because it breaks internal builds; see D39076487.
1 parent 38e5e4a · commit c7edcd6

81 files changed: 715 additions, 752 deletions


.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-a668569f7f9b7ecd946cf2551d30d482799d597d
+9b2f7929c2dae841888a836449c25b04c8cf4045

aten/src/ATen/BatchingRegistrations.cpp

Lines changed: 8 additions & 8 deletions
@@ -186,8 +186,7 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
 }
 
 Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  // TODO: properly support this
-  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
+  return self.expand(asIntArrayRefSlow(psize), implicit);
 }
 
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@@ -470,8 +469,7 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
 }
 
 Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  // TODO: properly support this
-  return view_batching_rule(self, asIntArrayRefSlow(size));
+  return self.view(asIntArrayRefSlow(size));
 }
 
 Tensor view_as_complex_batching_rule(const Tensor& self) {
@@ -1011,7 +1009,6 @@ Tensor new_empty_symint_batching_rule(
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
-  // TODO: properly support this
   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
 }
 
@@ -1112,7 +1109,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_symint_batching_rule);
+  m.impl("expand", expand_batching_rule);
+  m.impl("expand.SymInt", expand_symint_batching_rule);
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1140,7 +1138,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_symint_batching_rule);
+  m.impl("view", view_batching_rule);
+  m.impl("view.SymInt", view_symint_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // clamp operations
@@ -1278,7 +1277,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_batching_rule);
+  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
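For context on the registrations above: "expand" and "expand.SymInt" are distinct schema overloads, so each needs its own m.impl(...) entry for the Batched key, which is why the revert re-adds both lines. A minimal sketch of the registration pattern, assuming standard torch/library.h usage; the stub kernel is illustrative only and not part of this commit:

#include <torch/library.h>

namespace {

// Illustrative stand-in for a real batching rule; a genuine rule would
// unwrap the BatchedTensor and handle the batch dimension explicitly.
at::Tensor toy_expand_batching_rule(const at::Tensor& self, at::IntArrayRef size, bool implicit) {
  return self.expand(size, implicit);
}

} // namespace

TORCH_LIBRARY_IMPL(aten, Batched, m) {
  // One registration per overload: registering "expand" does not cover
  // "expand.SymInt"; the SymInt overload needs its own m.impl(...) line.
  m.impl("expand", toy_expand_batching_rule);
}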

aten/src/ATen/FunctionalInverses.cpp

Lines changed: 15 additions & 4 deletions
@@ -137,8 +137,12 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
   return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
 }
 
-Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
+Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
+  return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
+}
+
+Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
+  return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
 }
 
 Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
@@ -287,15 +291,22 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
   return base.select_scatter(mutated_view, dim, mutated_view_idx);
 }
 
-Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
+Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
+  if (reapply_views) {
+    return mutated_view.view(base.sizes());
+  } else {
+    return at::view_copy(mutated_view, base.sizes());
+  }
+}
+
+Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
   if (reapply_views) {
     return mutated_view.view_symint(base.sym_sizes());
   } else {
     return at::view_copy_symint(mutated_view, base.sym_sizes());
   }
 }
 
-
 Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
   if (reapply_views) {
     return mutated_view.view(base.scalar_type());
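A brief aside on why the expand inverses above call at::sum_to: expanding broadcasts size-1 dimensions, so the functional inverse reduces the mutated view back to the base's shape by summing over the broadcast dimensions. A self-contained sketch with illustrative shapes, not code from this commit:

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <iostream>

int main() {
  at::Tensor base = at::randn({3, 1});
  at::Tensor view = base.expand({3, 4});            // broadcast dim 1 from 1 -> 4
  at::Tensor back = at::sum_to(view, base.sizes()); // sum dim 1 back down to 1
  std::cout << back.sizes() << std::endl;           // prints [3, 1]
}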

aten/src/ATen/core/NamedRegistrations.cpp

Lines changed: 1 addition & 0 deletions
@@ -179,6 +179,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("exp.out", CppFunction::makeFallthrough());
   m.impl("exp_", CppFunction::makeFallthrough());
   m.impl("expand", CppFunction::makeFallthrough());
+  m.impl("expand.SymInt", CppFunction::makeFallthrough());
   m.impl("expm1", CppFunction::makeFallthrough());
   m.impl("expm1.out", CppFunction::makeFallthrough());
   m.impl("expm1_", CppFunction::makeFallthrough());

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

Lines changed: 1 addition & 8 deletions
@@ -353,14 +353,7 @@ namespace impl {
 template<bool AllowDeprecatedTypes>
 struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
   static std::vector<c10::SymInt> call(IValue& v) {
-    if (v.isIntList()) {
-      std::vector<c10::SymInt> r;
-      auto src = v.toIntList();
-      std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
-      return r;
-    } else {
-      return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
-    }
+    return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
   }
 };
 template<class T, bool AllowDeprecatedTypes>
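The deleted branch above special-cased IValues holding a plain int list, widening each int64_t into a c10::SymInt before handing the list on. A standalone sketch of that widening step; the helper name is hypothetical:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>
#include <c10/core/SymInt.h>

// Hypothetical helper mirroring the removed std::transform: each concrete
// int64_t becomes a SymInt wrapping that constant value.
std::vector<c10::SymInt> widen_to_symint(const std::vector<int64_t>& src) {
  std::vector<c10::SymInt> out;
  out.reserve(src.size());
  std::transform(src.begin(), src.end(), std::back_inserter(out),
                 [](int64_t i) { return c10::SymInt(i); });
  return out;
}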

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 1 addition & 2 deletions
@@ -35,8 +35,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 
 namespace {
   void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
-    // TODO: figure out if we can just directly save real schema at def time
-    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def.cloneWithRealTypes(), inferred);
+    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
     if (schema_difference.has_value()) {
       TORCH_CHECK(false,
         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"

aten/src/ATen/core/dynamic_type.cpp

Lines changed: 4 additions & 0 deletions
@@ -231,6 +231,8 @@ TypePtr DynamicType::fallback() const {
       return BoolType::get();
     case Tag::Int:
       return IntType::get();
+    case Tag::SymInt:
+      return SymIntType::get();
     case Tag::Float:
       return FloatType::get();
     case Tag::Complex:
@@ -324,6 +326,8 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
       return DynamicTypeTrait<ComplexType>::getBaseType();
     case Tag::Int:
       return DynamicTypeTrait<IntType>::getBaseType();
+    case Tag::SymInt:
+      return DynamicTypeTrait<SymIntType>::getBaseType();
     case Tag::Bool:
       return DynamicTypeTrait<BoolType>::getBaseType();
     case Tag::String:

aten/src/ATen/core/dynamic_type.h

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);
 
 constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
 constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
+constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23);
 constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
 constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
 constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
@@ -28,6 +29,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
   _(Bool, DYNAMIC_TYPE_BIT(2), 1) \
   _(Int, kDynamicIntTypeBit, 1) \
   _(Float, kDynamicFloatTypeBit, 1) \
+  _(SymInt, kDynamicSymIntTypeBit, 1) \
   _(Complex, kDynamicComplexTypeBit, 1) \
   _(Number, \
     (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \
@@ -61,7 +63,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
 #define FORALL_DYNAMIC_TYPES_FAKE(_) \
   _(ScalarType, kDynamicIntTypeBit, 1) \
   _(Layout, kDynamicIntTypeBit, 1) \
-  _(SymInt, kDynamicIntTypeBit, 1) \
   _(MemoryFormat, kDynamicIntTypeBit, 1)
 
 #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
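On the bit constants above: DynamicType gives each leaf type its own bit and builds union types as ORs of member bits (Number is Int | Float | Complex), so this revert promotes SymInt from piggybacking on Int's bit in the FAKE list to owning DYNAMIC_TYPE_BIT(23). A simplified sketch of how such a bitset can encode subtyping; the isSubtype helper is my assumption about the scheme, not code from this commit, and the real header uses a DYNAMIC_TYPE_BIT macro rather than a function:

#include <cstdint>

using DynamicTypeBits = std::uint32_t;
constexpr DynamicTypeBits bit(unsigned n) { return DynamicTypeBits{1} << n; }

constexpr DynamicTypeBits kInt     = bit(3);
constexpr DynamicTypeBits kFloat   = bit(4);
constexpr DynamicTypeBits kComplex = bit(5);
constexpr DynamicTypeBits kSymInt  = bit(23); // its own bit after the revert
constexpr DynamicTypeBits kNumber  = kInt | kFloat | kComplex;

// Assumed semantics: T <: U iff T's bits are contained in U's bits.
constexpr bool isSubtype(DynamicTypeBits sub, DynamicTypeBits super) {
  return (sub & super) == sub;
}

static_assert(isSubtype(kInt, kNumber), "Int <: Number");
static_assert(!isSubtype(kSymInt, kNumber), "SymInt no longer aliases Int's bit");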

aten/src/ATen/core/function_schema.cpp

Lines changed: 0 additions & 16 deletions
@@ -17,22 +17,6 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
   }
 }
 
-FunctionSchema FunctionSchema::cloneWithRealTypes() const {
-  auto cloneWithRealTypes = [](const Argument& a) {
-    return a.cloneWithType(a.real_type());
-  };
-  std::vector<Argument> new_arguments, new_returns;
-  std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
-  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
-  return FunctionSchema(
-      name(),
-      overload_name(),
-      std::move(new_arguments),
-      std::move(new_returns),
-      is_vararg(),
-      is_varret());
-}
-
 bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const {
   if (!lhs || !rhs) {
     return false;

aten/src/ATen/core/function_schema.h

Lines changed: 1 addition & 5 deletions
@@ -44,7 +44,7 @@ struct Argument {
       c10::optional<AliasInfo> alias_info = c10::nullopt)
       : name_(std::move(name)),
         type_(fake_type ? std::move(fake_type) : TensorType::get()),
-        real_type_(real_type ? std::move(real_type) : type_),
+        real_type_(real_type ? std::move(real_type) : TensorType::get()),
         N_(std::move(N)),
         default_value_(std::move(default_value)),
         alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
@@ -88,8 +88,6 @@ struct Argument {
   const TypePtr& type() const {
     return type_;
   }
-  // if type() is non-null, this is guaranteed to be non-null (if no real
-  // type was provided, this takes on type()'s value)
   const TypePtr& real_type() const {
     return real_type_;
   }
@@ -474,8 +472,6 @@ struct TORCH_API FunctionSchema {
   FunctionSchema cloneWithRemappedTypes(
       const std::function<TypePtr(TypePtr)> type_map) const;
 
-  FunctionSchema cloneWithRealTypes() const;
-
   // Check that inputs have the correct types and appends any missing default
   // values.
   template <typename T = c10::PlatformType>
