Commit 0161e9e

[test] attempt to functionalize ops with mutable positional-only args

Pull Request resolved: pytorch#76320
Approved by: https://github.com/ezyang

bdhirsh authored and pytorchmergebot committed May 19, 2022
1 parent b8639cf commit 0161e9e

Showing 22 changed files with 1,285 additions and 278 deletions.
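At a high level, functionalization rewrites a program so that every op that mutates one of its arguments is replaced by a functional (non-mutating) variant, with downstream uses of the mutated tensor redirected to the new result. A minimal conceptual sketch in Python (this is not PyTorch's actual pass; the names are illustrative):

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    x.add_(1)              # mutable op
    return x

def f_functionalized(x: torch.Tensor) -> torch.Tensor:
    x1 = torch.add(x, 1)   # functional variant; x is never mutated
    return x1              # downstream uses of x are redirected to x1

t = torch.zeros(2)
assert torch.equal(f(t.clone()), f_functionalized(t))

The changes below code-generate functional variants for mutable ops (e.g. zero.functional for zero_) and teach the python-binding, tracing, autograd, and lazy-tensor codegen about them.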
109 changes: 107 additions & 2 deletions aten/src/ATen/native/native_functions.yaml

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions aten/src/ATen/native/ts_native_functions.yaml
@@ -19,7 +19,7 @@ full_codegen:
 - avg_pool2d_backward
 - baddbmm
 - bernoulli
-- bernoulli_.float
+- bernoulli.p
 - binary_cross_entropy
 - binary_cross_entropy_backward
 - bitwise_and.Tensor
@@ -72,8 +72,8 @@ full_codegen:
 - log_sigmoid_forward
 - lt.Scalar
 - lt.Tensor
-- masked_fill_.Scalar
-- masked_fill_.Tensor
+- masked_fill.Scalar
+- masked_fill.Tensor
 - max
 - max.dim
 - max_pool2d_with_indices
@@ -101,12 +101,11 @@ full_codegen:
 - norm.ScalarOpt_dim
 - pow.Tensor_Scalar
 - pow.Tensor_Tensor
-- random_
-- random_.from
-- random_.to
+- random.functional
+- random.from_functional
+- random.to_functional
 - reciprocal
 - relu
-- relu_
 - remainder.Tensor
 - repeat
 - rsqrt
@@ -141,7 +140,7 @@ full_codegen:
 - upsample_bilinear2d_backward
 - upsample_nearest2d
 - upsample_nearest2d_backward
-- zero_
+- zero.functional
 - narrow_copy.SymInt
 supported:
 - as_strided
17 changes: 17 additions & 0 deletions aten/src/ATen/templates/CompositeViewCopyKernels.cpp
@@ -13,8 +13,25 @@
 namespace at {
 namespace native {
 
+// This file contains a number of kernels for aten functions that are fully code-generated.
+// TODO: rename this file to something more generic.
+
+at::Tensor clone_arg(const at::Tensor& t) {
+  return t.clone();
+}
+
+std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
+  std::vector<at::Tensor> out(t_list.size());
+  for (const auto& i : c10::irange(t_list.size())) {
+    out[i] = t_list[i].clone();
+  }
+  return out;
+}
+
+
 ${CompositeViewCopyKernel_Definitions}
 
+${GeneratedCompositeFunctional_Definitions}
+
 } // namespace native
 } // namespace at
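The clone_arg helpers above support the generated functional kernels: the functional variant of a mutable op clones the arguments the op would have mutated and applies the mutation to the clones, leaving the inputs untouched. A hedged Python rendering of the idea (the generated kernels themselves are C++, and zero_functional here is an illustrative name):

import torch

def zero_functional(self: torch.Tensor) -> torch.Tensor:
    self_clone = self.clone()   # per-tensor, this is what clone_arg does
    return self_clone.zero_()   # mutate the clone and return it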
4 changes: 3 additions & 1 deletion aten/src/ATen/test/atest.cpp
@@ -185,7 +185,9 @@ TEST_F(atest, ne_operators) {
 
 TEST_F(atest, add_operators) {
   auto exp_tensor = tensor({-10, 1, 0, -1, 10});
-  run_binary_ops_test(add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
+  run_binary_ops_test<
+      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Scalar&)>(
+      add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
 }
 
 TEST_F(atest, max_operators) {
11 changes: 11 additions & 0 deletions test/test_functionalization.py
@@ -174,6 +174,17 @@ def f(x):
        [1., 1.],
        [1., 1.]]))""")
 
+    # Some ops that are mutable are neither inplace nor out= ops.
+    # They also need special handling.
+    def test_mutable_op_not_inplace_or_other(self):
+        def f(x):
+            return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
+
+        logs = self.get_logs(f, torch.ones(1))
+        self.assertExpectedInline('\n'.join(logs), """\
+$0 = input('input')
+$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")
+
     def test_tensor_list_composite(self):
         def f(x):
             # Test an op with TensorList input
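For ops like this, the generated .functional overload returns fresh values for the arguments the op would have mutated, alongside the op's regular outputs; that is why the traced call above binds six results. A generic, hypothetical sketch of that calling convention (not the real op's math):

import torch
from typing import Tuple

def mutable_op_functional(x: torch.Tensor, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Instead of writing the new state into `state` in place,
    # return it as an extra output for the caller to copy back.
    new_state = state + 1
    out = x * new_state
    return out, new_state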
2 changes: 0 additions & 2 deletions test/test_fx.py
@@ -3944,8 +3944,6 @@ def tearDown(self):
 
     "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
     "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
-
-    "normalize" : MUTABLE,
 }
 
 # List of nn.functionals with Tensor inputs but not with type annotation
3 changes: 3 additions & 0 deletions tools/autograd/gen_python_functions.py
@@ -177,6 +177,9 @@
 
 @with_native_function
 def should_generate_py_binding(f: NativeFunction) -> bool:
+    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
+    if "generated" in f.tags:
+        return False
     name = cpp.name(f.func)
     for skip_regex in SKIP_PYTHON_BINDINGS:
         if skip_regex.match(name):
3 changes: 2 additions & 1 deletion tools/autograd/gen_trace_type.py
@@ -432,7 +432,8 @@ def emit_trace_body(f: NativeFunction) -> List[str]:
 
     assign_return_values = (
         f"{tie_return_values(f)} = "
-        if f.func.kind() == SchemaKind.functional and f.func.returns
+        if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
+        and f.func.returns
         else ""
     )
 
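The updated condition treats SchemaKind.mutable like SchemaKind.functional, since the functional counterparts of mutable ops return values that the trace must bind to variables. A simplified rendering of the check (using a toy SchemaKind, not torchgen's):

from enum import Enum, auto

class SchemaKind(Enum):
    functional = auto()
    inplace = auto()
    out = auto()
    mutable = auto()

def needs_return_assignment(kind: SchemaKind, has_returns: bool) -> bool:
    # Mirrors the updated emit_trace_body condition, simplified.
    return kind in (SchemaKind.functional, SchemaKind.mutable) and has_returns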
15 changes: 11 additions & 4 deletions tools/autograd/gen_variable_type.py
@@ -349,7 +349,7 @@
 GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
-RESET_GRAD_ACCUMULATOR = {"set", "resize"}
+RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
 
 # NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
 #
@@ -734,7 +734,7 @@ def gen_variable_type_func(
 
     if (
         fn.info is None
-        and not get_base_name(f) in RESET_GRAD_ACCUMULATOR
+        and not str(f.func.name.name) in RESET_GRAD_ACCUMULATOR
         and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE
         and len(gen_differentiable_outputs(fn)) > 0
         and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
@@ -857,7 +857,14 @@ def find_args_with_derivatives(
         and (len(differentiable_outputs) > 0)
     )
 
-    if info is not None and info.has_derivatives and not requires_derivative:
+    if (
+        info is not None
+        and info.has_derivatives
+        and not requires_derivative
+        # out= ops are allowed to have zero returns, which causes requires_derivative to be False;
+        # we shouldn't error out though (out= ops for autograd just redispatch)
+        and len(f.func.returns) > 0
+    ):
         raise RuntimeError(
             f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
         )
@@ -1528,7 +1535,7 @@ def get_msg() -> str:
     # Save only after the forward AD has been set up
     body.append(emit_save_outputs())
 
-    if base_name in RESET_GRAD_ACCUMULATOR:
+    if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
         # `inplace` implies that there is exactly one output named `self`,
         # so we can keep the generated code easy. If you need to
        # `reset_grad_accumulator` in an operator that's not `inplace`, you can
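The two RESET_GRAD_ACCUMULATOR changes go together: the lookup key switched from get_base_name(f), which drops the in-place underscore, to str(f.func.name.name), which keeps it, so the set entries must now spell out "set_" and "resize_". A hedged sketch with a simplified stand-in for torchgen's BaseOperatorName:

from dataclasses import dataclass

@dataclass
class BaseOperatorName:  # simplified stand-in, not torchgen's class
    base: str
    inplace: bool

    def __str__(self) -> str:
        return self.base + ("_" if self.inplace else "")

RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}

name = BaseOperatorName(base="resize", inplace=True)
assert str(name) in RESET_GRAD_ACCUMULATOR       # new-style key: matches
assert name.base not in RESET_GRAD_ACCUMULATOR   # old-style key would miss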
5 changes: 4 additions & 1 deletion tools/autograd/load_derivatives.py
@@ -32,7 +32,10 @@
     stringT,
 )
 from torchgen.api import cpp
-from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions
+from torchgen.gen import (
+    parse_native_yaml,
+    get_grouped_by_view_native_functions,
+)
 from torchgen.context import with_native_function
 from torchgen.model import (
     FunctionSchema,
26 changes: 11 additions & 15 deletions torch/csrc/lazy/core/shape_inference.cpp
@@ -81,10 +81,11 @@ std::vector<int64_t> expand_param_if_needed(
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 
-std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
+TORCH_API std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
   double size_d = 0;
   // shape inference code copied from RangeFactories.cpp arange_out function
   // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct c++ scalar_t type depending on the out tensor
+
   AT_DISPATCH_ALL_TYPES_AND(c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
     // Note: acc_type further defines an accumulation type depending on the scalar_t and whether it's on cuda vs cpu.
     using accscalar_t = at::acc_type<scalar_t, false>;
@@ -129,7 +130,6 @@ std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::
 // If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
 // Otherwise, the dtype is inferred to be torch.int64.
 
-// Since out tensor is specified, its dtype should always be used?
 return {Shape(out.scalar_type(), {size})};
 }
 
@@ -145,7 +145,7 @@ std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optiona
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator) {
+std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator) {
   return compute_shape_bernoulli(self, generator);
 }
 
@@ -224,11 +224,11 @@ std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at:
   }
 }
 
-std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
+std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
+std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
@@ -380,26 +380,22 @@ std::vector<Shape> compute_shape_native_dropout_backward(const at::Tensor & grad
   return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
 }
 
-std::vector<Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator) {
+std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
-  return compute_shape_random_(self, generator);
+std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
+  return compute_shape_random_functional(self, generator);
 }
 
-std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
-  return compute_shape_random_(self, generator);
+std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
+  return compute_shape_random_functional(self, generator);
 }
 
 std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<Shape> compute_shape_relu_(at::Tensor& self) {
-  return compute_shape_relu(self);
-}
-
 std::vector<Shape> compute_shape_bitwise_and(const at::Tensor& self, const at::Scalar& other) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
@@ -417,7 +413,7 @@ std::vector<Shape> compute_shape_sum(
   return {Shape(self.scalar_type(), {})};;
 }
 
-std::vector<Shape> compute_shape_zero_(at::Tensor& self) {
+std::vector<Shape> compute_shape_zero_functional(const at::Tensor& self) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
15 changes: 7 additions & 8 deletions torch/csrc/lazy/core/shape_inference.h
@@ -16,7 +16,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d_bac
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
@@ -37,8 +37,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_l1_loss_backward(const a
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_forward(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logdet(const at::Tensor & self);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
@@ -50,11 +50,10 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backwa
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu_(at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending);
@@ -65,7 +64,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & s
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_(at::Tensor & self);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
 } // namespace lazy
 } // namespace torch
34 changes: 29 additions & 5 deletions torchgen/api/autograd.py
@@ -310,17 +310,41 @@ def match_differentiability_info(
         for info in differentiability_infos
         if info.func.func.kind() == SchemaKind.functional
     }
+    non_functional_info_by_signature = {
+        info.func.func.signature(strip_default=True): info
+        for info in differentiability_infos
+        if info.func.func.kind() != SchemaKind.functional
+    }
 
     def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
+        # (1) Check for an exact match
        if f.func in info_by_schema:
             return info_by_schema[f.func], True
 
-        # if there is no exact match look for the out-of-place signature.
+        # (2) If there is no exact match, check whether the out-of-place variant
+        # of this operator has a match,
         # i.e. mul() for mul_() or mul_out()
-        return (
-            functional_info_by_signature.get(f.func.signature(strip_default=True)),
-            False,
-        )
+        f_sig = f.func.signature(strip_default=True)
+        if f_sig in functional_info_by_signature:
+            return functional_info_by_signature[f_sig], False
+
+        # (3) Some operators have a derivative explicitly defined for the mutable
+        # variant, but get a code-generated out-of-place variant which does *not*
+        # come with a derivative formula.
+        # For the generated out-of-place variant, use the mutable variant's formula
+        # if it exists.
+        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
+            info = non_functional_info_by_signature[f_sig]
+            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
+            assert not any(
+                "self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs
+            ), f"""\
+Attempted to convert a derivative formula for a mutable operator
+to be used automatically by its functional variant ("{str(f.func)}").
+This is not currently supported (we'd need to fix up the formula in the codegen)."""
+            return info, False
+
+        return None, False
 
     result: List[NativeFunctionWithDifferentiabilityInfo] = []
     for f in native_functions:
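The three-step lookup in find_info can be illustrated with plain dicts standing in for the schema-keyed tables (hypothetical and heavily simplified; the real keys are FunctionSchema objects, not strings):

def find_info_toy(op, exact, functional, non_functional, tags):
    # (1) Exact schema match.
    if op in exact:
        return exact[op], True
    # Crude stand-in for f.func.signature(strip_default=True).
    sig = op.rstrip("_")
    # (2) The out-of-place variant's formula.
    if sig in functional:
        return functional[sig], False
    # (3) A generated op may reuse its mutable variant's formula.
    if "generated" in tags and sig in non_functional:
        return non_functional[sig], False
    return None, False

assert find_info_toy("mul_", {}, {"mul": "d(mul)"}, {}, set()) == ("d(mul)", False)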
4 changes: 3 additions & 1 deletion torchgen/api/cpp.py
@@ -64,7 +64,9 @@
 
 def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
     name = str(func.name.name)
-    if func.is_out_fn():
+    if func.is_functional_fn():
+        name += "_functional"
+    elif func.is_out_fn():
         if faithful_name_for_out_overloads:
             name += "_outf"
         else:
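With this change, the C++ name of a functional overload gains a _functional suffix (e.g. zero.functional becomes zero_functional), alongside the existing _out / _outf convention for out= overloads. A simplified rendering (the real code inspects a FunctionSchema):

def cpp_name_toy(base: str, is_functional: bool, is_out: bool, faithful: bool = False) -> str:
    if is_functional:
        return base + "_functional"   # e.g. zero.functional -> zero_functional
    if is_out:
        return base + ("_outf" if faithful else "_out")
    return base

assert cpp_name_toy("zero", is_functional=True, is_out=False) == "zero_functional"
assert cpp_name_toy("add", is_functional=False, is_out=True) == "add_out"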