Put symint overloads on a different name
Due to implicit conversion shenanigans, having both IntArrayRef
and SymIntArrayRef overloads makes {} ambiguous.  While we could
fix this by making a single unified type that accepts all the overloads
we want, an easier fix was to just push the SymIntArrayRef overload
to its own name.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: pytorch#79281

Approved by: https://github.com/suo
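For illustration, here is a minimal, self-contained C++ sketch of the ambiguity the commit message describes. The ArrayRef, SymInt, and expand* names below are simplified stand-ins, not the real c10/ATen declarations; the point is only that an empty braced-init-list `{}` can initialize either array-ref type equally well, so two same-named overloads cannot be resolved, while giving the SymInt overload its own name sidesteps the problem.

#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Stand-in for c10::SymInt: implicitly constructible from int64_t.
struct SymInt {
  int64_t value;
  SymInt(int64_t v) : value(v) {}
};

// Stand-in for c10::ArrayRef<T>: constructible from {}, initializer lists,
// and std::vector<T>.
template <class T>
struct ArrayRef {
  const T* data = nullptr;
  size_t len = 0;
  ArrayRef() = default;
  ArrayRef(std::initializer_list<T> il) : data(il.begin()), len(il.size()) {}
  ArrayRef(const std::vector<T>& v) : data(v.data()), len(v.size()) {}
};

using IntArrayRef = ArrayRef<int64_t>;
using SymIntArrayRef = ArrayRef<SymInt>;

// Two overloads under one name: the situation before this commit.
void expand(IntArrayRef size) {}
void expand(SymIntArrayRef size) {}

// The fix: push the SymIntArrayRef overload to its own name.
void expand_fixed(IntArrayRef size) {}
void expand_fixed_symint(SymIntArrayRef size) {}

int main() {
  // expand({});           // error: ambiguous -- {} converts equally well to
                           // IntArrayRef and SymIntArrayRef
  expand_fixed({});        // OK: only one overload is named expand_fixed
  expand_fixed_symint({}); // OK: the SymInt overload is requested explicitly
  return 0;
}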
ezyang authored and pytorchmergebot committed Jun 12, 2022
1 parent f23685f commit 213a8fc
Showing 8 changed files with 23 additions and 9 deletions.
aten/src/ATen/native/TensorShape.cpp (1 addition, 1 deletion)
@@ -3394,7 +3394,7 @@ at::Tensor& diagonal_copy_out(const at::Tensor & self, int64_t offset, int64_t d
 
 at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out) {
-  auto tmp = self.expand(size, implicit);
+  auto tmp = self.expand_symint(size, implicit);
   out.copy_(tmp);
   return out;
 }
test/cpp/jit/test_misc.cpp (2 additions, 2 deletions)
@@ -1386,7 +1386,7 @@ TEST(TestSymIntArrayRef, BasicConversion) {
   std::vector<int64_t> tgt_size_v{2, 4, 5};
   std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
   auto a = at::randn({1, 4, 1}, at::kCPU);
-  auto b = a.expand(tgt_size);
+  auto b = a.expand_symint(tgt_size);
   auto c = a.expand(tgt_size_v);
   ASSERT_TRUE(torch::allclose(b, c));
 }
@@ -1395,7 +1395,7 @@ TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
   static const size_t LENGTH = 5;
   auto a = at::randn({10}, at::kCPU);
   c10::SymInt si(LENGTH);
-  auto b = a.narrow_copy(0, 0, si);
+  auto b = a.narrow_copy_symint(0, 0, si);
   auto c = a.narrow(0, 0, LENGTH);
   ASSERT_TRUE(torch::allclose(b, c));
 }
test/cpp/lazy/test_lazy_ops.cpp (2 additions, 2 deletions)
@@ -90,7 +90,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
   auto ly = torch::lazy::TryGetLtcTensor(y);
   auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
   auto lmn = std::make_shared<torch::lazy::SymbolicIntNode>(dim_node);
-  auto z = x.narrow_copy(X_DIM_INDEX, 0, lmn->toSymInt());
+  auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt());
   AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
 }

@@ -100,7 +100,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
   const size_t Y_DIM = 3;
   const size_t X_DIM_INDEX = 0;
   auto y = torch::rand({Y_DIM}).to(kLazy);
-  auto z = x.narrow_copy(X_DIM_INDEX, 0, y.sym_sizes()[0]);
+  auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, y.sym_sizes()[0]);
   auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM);
   ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
   // shape inference assumes narrow_copy can copy the whole tensor
tools/autograd/load_derivatives.py (9 additions, 1 deletion)
@@ -339,6 +339,10 @@ def repl(m: Any) -> str:
             arg_name = arg_name + "_t"
         new_args.append(arg_name)
 
+    # TODO we are trolling
+    if f.func.is_symint_fn():
+        defn_name += "_symint"
+
     # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
     if Variant.function in f.variants:
         fw_formula = "at::{}({})".format(defn_name, ", ".join(new_args))
@@ -396,7 +400,11 @@ def canonical_function(
     functions: Sequence[NativeFunction], name: str
 ) -> NativeFunction:
     for f in functions:
-        if cpp.name(f.func) == name:
+        if (
+            not f.func.is_functional_fn()
+            and not f.func.is_out_fn()
+            and name == str(f.func.name.name)
+        ):
             return f
     # some functions only have in-place variants
     assert name + "_" == cpp.name(functions[0].func)
torch/csrc/lazy/core/shape_inference.cpp (1 addition, 1 deletion)
@@ -924,7 +924,7 @@ std::vector<Shape> compute_shape_repeat(
   return {Shape(self.scalar_type(), target_size)};
 }
 
-std::vector<Shape> compute_shape_narrow_copy(
+std::vector<Shape> compute_shape_narrow_copy_symint(
     const at::Tensor& self,
     int64_t dim,
     int64_t start,
torch/csrc/lazy/core/shape_inference.h (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & s
 TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
-TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
 
 // Non-Native ops
 TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
torchgen/api/cpp.py (3 additions, 1 deletion)
@@ -66,7 +66,9 @@ def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False)
     name = str(func.name.name)
     if func.is_functional_fn():
         name += "_functional"
-    elif func.is_out_fn():
+    if func.is_symint_fn():
+        name += "_symint"
+    if func.is_out_fn():
         if faithful_name_for_out_overloads:
             name += "_outf"
         else:
torchgen/model.py (4 additions, 0 deletions)
@@ -1225,6 +1225,10 @@ def __post_init__(self) -> None:
     def is_functional_fn(self) -> bool:
         return "functional" in self.name.overload_name
 
+    def is_symint_fn(self) -> bool:
+        # TODO: make this more robust
+        return "SymInt" in self.name.overload_name
+
     def is_out_fn(self) -> bool:
         # Note [is_out_fn]
         #
