Commit 3c20b5b

goldenxuett authored and pytorchmergebot committed
Change autogradNotImplementedFallback to utilize is_mutable and is_aliasing from FunctionSchema (pytorch#81917)
- Simplifies the autogradNotImplementedFallback methods to use the new FunctionSchema functionality (is_mutable and is_aliasing)
- Tested by running the autogradNotImplementedFallback tests in autograd.cpp

Pull Request resolved: pytorch#81917
Approved by: https://github.com/soulitzer
1 parent eb2ea9a commit 3c20b5b
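
For context on the API this commit switches to, here is a minimal, illustrative sketch (not part of the commit) of querying a FunctionSchema with the is_aliasing/is_mutable helpers the diff below relies on. It assumes torch::jit::parseSchema and the SchemaArgument-taking overloads of is_aliasing and is_mutable are available in the build.

// Illustrative sketch only; assumes the FunctionSchema helpers referenced in
// this commit (is_aliasing/is_mutable taking a SchemaArgument) are available.
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <iostream>

int main() {
  // "Tensor(a!)" marks an argument that belongs to alias set 'a' and is
  // written to in place, i.e. an aliasing *and* mutable argument.
  const c10::FunctionSchema schema = torch::jit::parseSchema(
      "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)");

  const c10::SchemaArgument self_input{c10::SchemaArgType::input, 0};
  const c10::SchemaArgument first_output{c10::SchemaArgType::output, 0};

  std::cout << std::boolalpha
            << "input 0 aliasing:  " << schema.is_aliasing(self_input) << '\n'
            << "input 0 mutable:   " << schema.is_mutable(self_input) << '\n'
            << "output 0 aliasing: " << schema.is_aliasing(first_output) << '\n'
            << "output 0 mutable:  " << schema.is_mutable(first_output) << '\n';
  return 0;
}

All four queries print true for this schema, whereas a view-style annotation such as "Tensor(a)" without the "!" would be aliasing but not mutable; that is exactly the distinction the fallback below uses to tell in-place outputs from view outputs.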

File tree

torch/csrc/autograd/autograd_not_implemented_fallback.cpp

1 file changed: +44 -55 lines


torch/csrc/autograd/autograd_not_implemented_fallback.cpp

Lines changed: 44 additions & 55 deletions
@@ -53,44 +53,40 @@ void autogradNotImplementedFallbackImpl(
   // See gen_variable_type.py
   const auto& schema = op.schema();
   const auto& op_name = schema.operator_name().name;
-  const auto& arguments = schema.arguments();
-  const auto& returns = schema.returns();
-  const auto num_arguments = arguments.size();
-  const auto num_returns = returns.size();
+  const auto num_arguments = schema.arguments().size();
+  const auto num_returns = schema.returns().size();
   const auto stack_start = stack->size() - num_arguments;
   const bool grad_mode = GradMode::is_enabled();
   std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;

   // Keep track of which outputs are output of in-place modification
   // so we can rebase_history if necessary
-  std::vector<bool> is_inplace_output;
+  std::vector<bool> is_inplace_output(num_returns, false);
   bool any_is_inplace_output = false;
-  std::vector<bool> is_aliased_output;
-  is_inplace_output.reserve(num_returns);
-  is_aliased_output.reserve(num_returns);
-
-  for (const auto i : c10::irange(num_returns)) {
-    const at::AliasInfo* alias_info = returns[i].alias_info();
-    is_inplace_output.push_back(alias_info != nullptr && alias_info->isWrite());
-    any_is_inplace_output |= alias_info != nullptr && alias_info->isWrite();
-    is_aliased_output.push_back(alias_info != nullptr);
-  }
-  int aliased_input_idx = -1;
+  std::vector<bool> is_aliased_output(num_returns, false);
   int aliased_output_idx = -1;
+
   for (const auto i : c10::irange(num_returns)) {
-    const at::AliasInfo* alias_info = returns[i].alias_info();
-    if (alias_info != nullptr && !alias_info->isWrite()) {
-      TORCH_CHECK(
-          aliased_output_idx == -1,
-          "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
-          "Non-composite functions where multiple outputs are aliased with inputs aren't supported."
-          "Please rewrite your function as a composite function.");
-      aliased_output_idx = i;
+    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
+      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
+        is_inplace_output[i] = true;
+        any_is_inplace_output = true;
+      } else {
+        TORCH_CHECK(
+            aliased_output_idx == -1,
+            "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
+            "Non-composite functions where multiple outputs are aliased with inputs aren't supported."
+            "Please rewrite your function as a composite function.");
+        aliased_output_idx = i;
+      }
+      is_aliased_output[i] = true;
     }
   }
+
+  int aliased_input_idx = -1;
   for (const auto i : c10::irange(num_arguments)) {
-    const at::AliasInfo* alias_info = arguments[i].alias_info();
-    if (alias_info != nullptr && !alias_info->isWrite()) {
+    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
+        !schema.is_mutable({c10::SchemaArgType::input, i})) {
       TORCH_CHECK(
           aliased_input_idx == -1,
           "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
@@ -121,8 +117,7 @@ void autogradNotImplementedFallbackImpl(

   _foreach_tensor(
       [&](size_t _, size_t i, const at::Tensor& t) {
-        const at::AliasInfo* alias_info = arguments[i].alias_info();
-        if (alias_info != nullptr && alias_info->isWrite()) {
+        if (schema.is_mutable({c10::SchemaArgType::input, i})) {
           check_inplace(t, any_requires_grad);
         }
       },
@@ -273,18 +268,16 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(
   // that is not allowed in the gen_inplace_or_view logic
   const auto& schema = op.schema();
   const auto& op_name = schema.operator_name().name;
-  const auto& arguments = schema.arguments();
-  const auto& returns = schema.returns();
-  const auto num_arguments = arguments.size();
-  const auto num_returns = returns.size();
+  const auto num_arguments = schema.arguments().size();
+  const auto num_returns = schema.returns().size();
   const auto stack_start = stack->size() - num_arguments;

   at::Tensor aliased_input;

   int64_t aliased_output_idx = -1;
   for (const auto i : c10::irange(num_returns)) {
-    const at::AliasInfo* alias_info = returns[i].alias_info();
-    if (alias_info != nullptr && !alias_info->isWrite()) {
+    if (schema.is_aliasing({c10::SchemaArgType::output, i}) &&
+        !schema.is_mutable({c10::SchemaArgType::output, i})) {
       TORCH_CHECK(
           aliased_output_idx == -1,
           "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a "
@@ -297,25 +290,22 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(

   int64_t aliased_input_idx = -1;
   for (const auto i : c10::irange(num_arguments)) {
-    const at::AliasInfo* alias_info = arguments[i].alias_info();
-    if (alias_info != nullptr) {
-      if (!alias_info->isWrite()) {
-        TORCH_CHECK(
-            aliased_input_idx == -1,
-            "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
-            "non-write alias annotation (i.e., 'Tensor(a)'). "
-            "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
-            "Please rewrite your function as a composite function.");
-        aliased_input_idx = i;
-        const c10::IValue& aliased_input_iv =
-            (*stack)[stack_start + i]; // get a reference to an ivalue on the
-                                       // stack
-        TORCH_CHECK(aliased_input_iv.isTensor());
-        aliased_input =
-            aliased_input_iv
-                .toTensor(); // TODO: Can we avoid saving this tensor and
-                             // incurring the refcount bump?
-      }
+    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
+        !schema.is_mutable({c10::SchemaArgType::input, i})) {
+      TORCH_CHECK(
+          aliased_input_idx == -1,
+          "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
+          "non-write alias annotation (i.e., 'Tensor(a)'). "
+          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
+          "Please rewrite your function as a composite function.");
+      aliased_input_idx = i;
+      const c10::IValue& aliased_input_iv =
+          (*stack)[stack_start + i]; // get a reference to an ivalue on the
+                                     // stack
+      TORCH_CHECK(aliased_input_iv.isTensor());
+      aliased_input =
+          aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor
+                                       // and incurring the refcount bump?
     }
   }
   // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above
@@ -334,8 +324,7 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(
   }

   for (const auto i : c10::irange(num_returns)) {
-    const at::AliasInfo* alias_info = returns[i].alias_info();
-    if (alias_info->isWrite()) {
+    if (schema.is_mutable({c10::SchemaArgType::output, i})) {
       increment_version((*stack)[stack->size() - num_returns + i].toTensor());
     }
   }
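
To make the before/after relationship explicit, the sketch below (illustrative only; classify_outputs is a hypothetical helper, not part of the commit) mirrors the new output-classification loop using the schema helpers, with the AliasInfo-based checks it replaces noted in comments.

// Illustrative only: mirrors the classification done by the updated fallback.
// classify_outputs is a hypothetical free function, not part of the commit.
#include <ATen/core/function_schema.h>
#include <vector>

struct OutputClassification {
  std::vector<bool> is_inplace_output;
  std::vector<bool> is_aliased_output;
  bool any_is_inplace_output = false;
};

OutputClassification classify_outputs(const c10::FunctionSchema& schema) {
  const auto num_returns = schema.returns().size();
  OutputClassification result{std::vector<bool>(num_returns, false),
                              std::vector<bool>(num_returns, false)};
  for (size_t i = 0; i < num_returns; ++i) {
    // New style: ask the schema directly whether output i aliases an input,
    // and whether that alias is a write (in-place) alias.
    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
      result.is_aliased_output[i] = true;
      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
        result.is_inplace_output[i] = true;
        result.any_is_inplace_output = true;
      }
    }
    // Old style (removed by this commit): inspect the Argument's AliasInfo.
    //   const at::AliasInfo* alias_info = schema.returns()[i].alias_info();
    //   aliased  <=> alias_info != nullptr
    //   in-place <=> alias_info != nullptr && alias_info->isWrite()
  }
  return result;
}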

0 commit comments
