Commit

Allow inferring size-nature from sizes passed to empty constructor (pytorch#109720)

This removes the need for many constrain_as_size calls: we now infer
that a value is size-like from the error checking that constructors
already perform on sizes.
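
A sketch of the user-visible effect (based on the updated
test_item_to_constructor below; exact traced output may vary by version):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        r = a.item()  # r is an unbacked SymInt under symbolic tracing
        # Previously this required constrain_as_size(r) first; now
        # torch.empty's size validation (expect_size) establishes that
        # r is size-like on its own.
        return torch.empty(r)

    g = make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,)))
    print(g.code)  # no sym_constrain_range_for_size call in the graph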

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#109720
Approved by: https://github.com/aakhundov
ezyang authored and pytorchmergebot committed Sep 21, 2023
1 parent 6ca964b commit 09622d8
Showing 10 changed files with 80 additions and 10 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/EmptyTensor.h
@@ -17,8 +17,8 @@ inline void check_size_nonnegative(ArrayRef<int64_t> size) {

 inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
   for (const auto& x : size) {
-    TORCH_SYM_CHECK(
-        x.sym_ge(0),
+    TORCH_CHECK(
+        x.expect_size(__FILE__, __LINE__),
         "Trying to create tensor with negative dimension ",
         x,
         ": ",
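
A minimal Python-level sketch of the updated check (expect_size here is a
stand-in; the real one lives on c10::SymInt and, for unbacked symbolic
ints, also records size-likeness as a side effect):

    def expect_size(x):
        # For plain ints this is just the >= 0 test; the real
        # implementation additionally advises the compiler that x
        # is a size when x is an unbacked symbolic int.
        return x >= 0

    def check_size_nonnegative(size):
        for x in size:
            if not expect_size(x):
                raise ValueError(
                    f"Trying to create tensor with negative dimension {x}: {list(size)}")

    check_size_nonnegative([2, 3])    # ok
    # check_size_nonnegative([2, -1])  # raises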
8 changes: 8 additions & 0 deletions c10/core/SymInt.cpp
@@ -106,6 +106,14 @@ int64_t SymInt::guard_int(const char* file, int64_t line) const {
   }
 }

+bool SymInt::expect_size(const char* file, int64_t line) const {
+  if (auto ma = maybe_as_int()) {
+    return *ma >= 0;
+  } else {
+    return toSymNodeImplUnowned()->expect_size(file, line);
+  }
+}
+
 SymInt operator-(const SymInt& s) {
   if (auto ma = s.maybe_as_int()) {
     return SymInt(-*ma);
8 changes: 8 additions & 0 deletions c10/core/SymInt.h
@@ -145,6 +145,14 @@ class C10_API SymInt {
   // number can be used to diagnose overspecialization.
   int64_t guard_int(const char* file, int64_t line) const;

+  // Insert a guard that this SymInt must be size-like, returning true if
+  // the integer actually is >= 0. Unlike manually performing a >= 0 test,
+  // if the SymInt in question is an unbacked SymInt (or, potentially in the
+  // future, if it contains unbacked SymInts), we will also treat the
+  // unbacked SymInt as statically testing >= 2 (which will prevent us from
+  // choking on, e.g., contiguity checks).
+  bool expect_size(const char* file, int64_t line) const;
+
   // Distinguish actual symbolic values from constants stored on the heap
   bool is_symbolic() const {
     return is_heap_allocated() &&
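
The ">= 2" treatment is what lets unbacked sizes survive layout
computations: questions like "is this tensor contiguous?" reduce to guards
such as u0 == 0 or u0 == 1, which have no answer for an unbacked value. A
hedged illustration (exact behavior may differ across PyTorch versions):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        u = a.item()           # unbacked SymInt
        t = torch.empty(u, 2)  # expect_size fires in the size check
        # Contiguity involves guards like u == 0 / u == 1; statically
        # assuming u >= 2 lets them evaluate without a data-dependent
        # error.
        return t.is_contiguous()

    make_fx(f, tracing_mode="symbolic")(torch.tensor([3]))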
5 changes: 5 additions & 0 deletions c10/core/SymNodeImpl.h
@@ -157,6 +157,11 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
     // with a better implementation!
     return guard_bool(file, line);
   };
+  virtual bool expect_size(const char* file, int64_t line) {
+    // No improvement for unbacked SymInts by default, replace this
+    // with a better implementation!
+    return ge(wrap_int(0))->guard_bool(file, line);
+  };
   virtual int64_t int_() {
     TORCH_CHECK(false, "NYI");
   };
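
The default falls back to an ordinary guard, which still chokes on
unbacked values; subclasses are expected to do better. A toy Python model
of the dispatch pattern (hypothetical classes, not the real SymNode
hierarchy):

    class BoolNode:
        def __init__(self, value):  # value is None when unknown
            self.value = value
        def guard_bool(self, file, line):
            if self.value is None:
                raise RuntimeError(f"data-dependent guard at {file}:{line}")
            return self.value
        def expect_true(self, file, line):
            # Deferred: assume true now, assert at runtime later.
            return True if self.value is None else self.value

    class IntNode:
        def __init__(self, hint=None):  # hint=None models "unbacked"
            self.hint = hint
        def ge(self, other):
            return BoolNode(None if self.hint is None else self.hint >= other)
        def expect_size(self, file, line):
            # Base-class fallback, as above: guard, so unbacked chokes.
            return self.ge(0).guard_bool(file, line)

    class SmartIntNode(IntNode):
        def expect_size(self, file, line):
            # Better implementation: defer instead of guarding.
            return self.ge(0).expect_true(file, line)

    SmartIntNode().expect_size("f.py", 1)  # True, no guard needed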
3 changes: 3 additions & 0 deletions test/dynamo/test_misc.py
@@ -7274,6 +7274,9 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
 ==> name_to_node: values don't match.
   >  Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
   >  Right: {}
+==> runtime_var_to_range: values don't match.
+  >  Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
+  >  Right: {}
 ==> source_to_symbol: values don't match.
   >  Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
   >  Right: {}
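
The magic numbers in the expected range: the upper bound is int64 max
minus one, and the lower bound of 2 reflects the default compile-time
assumption that symbolic sizes are not 0 or 1 (avoiding 0/1
specialization):

    assert 9223372036854775806 == 2**63 - 2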
4 changes: 3 additions & 1 deletion test/export/test_export.py
@@ -1008,7 +1008,9 @@ def test_constrain_as_size_error(self):

         def f(x):
             a = x.item()
-            return torch.full((a, 4), 0)
+            # We cannot automatically infer a is a size here because view
+            # accepts -1
+            return torch.randn(24).view(a, 4)

         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
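
Why inference fails here: view treats -1 as a "compute this dimension"
placeholder, so passing a value to view does not imply it is >= 0, and
export still demands an explicit size constraint (hence the UserError the
test expects). A sketch:

    import torch

    def f(x):
        a = x.item()
        # view cannot establish a >= 0 (view(-1, 4) is legal), so a's
        # size-nature is not inferred here.
        return torch.randn(24).view(a, 4)

    f(torch.tensor(6))  # fine eagerly (6 * 4 == 24); export would raise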
26 changes: 21 additions & 5 deletions test/test_proxy_tensor.py
@@ -977,6 +977,22 @@ def forward(self, y_1, x_1):
     index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
     return index_select""")

+    def test_repeat_interleave_unbacked_output_size(self):
+        def f(x, y):
+            s = x.sum().item()
+            return y.repeat_interleave(x, dim=0, output_size=s)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
+        self.assertExpectedInline(
+            r, """\
+def forward(self, x_1, y_1):
+    sum_1 = torch.ops.aten.sum.default(x_1)
+    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
+    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
+    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
+    return index_select"""  # noqa: B950
+        )
+
     def test_adv_index_batch(self):
         def f(src_tokens):
             bsz, src_len = src_tokens.size()[:2]
@@ -1067,15 +1083,13 @@ def forward(self, a_1):
     def test_item_to_constructor(self):
         def f(a):
             r = a.item()
-            constrain_as_size(r)
             return torch.empty(r)

         r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
         self.assertExpectedInline(
             r, """\
 def forward(self, a_1):
     _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
-    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
     empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
     return empty"""  # noqa: B950
         )
@@ -1325,15 +1339,17 @@ def f(a, b):
         self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

     def test_size_with_tensor(self):
+        # I think I messed up writing this test case originally: I'm
+        # supposed to hit an error case, but the code here works in both
+        # eager and tracing
         def f(tensor):
             max_size = torch.tensor([800, 1216], dtype=torch.int64)
             batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
             return tensor.new_empty(batch_shape)

         a = torch.randn(3, 800, 1199)
-        self.assertRaisesRegex(
-            RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a)
-        )
+        f(a)
+        make_fx(f, tracing_mode="symbolic")(a)

     def test_expand(self):
         def f(a):
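
The new repeat_interleave test above pins down the output_size= path:
without it, the length of repeat_interleave's result depends on x.sum(),
a data-dependent value. A small eager-mode usage example:

    import torch

    x = torch.tensor([2, 3])
    y = torch.randn(2)
    # output_size pre-declares the data-dependent output length, so
    # symbolic tracing need not guard on x.sum().
    out = y.repeat_interleave(x, dim=0, output_size=5)
    print(out.shape)  # torch.Size([5])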
5 changes: 5 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -1230,6 +1230,11 @@ void initJITBindings(PyObject* module) {
             [](c10::SymNode a, const char* file, int64_t line) {
               return a->expect_true(file, line);
             })
+        .def(
+            "expect_size",
+            [](c10::SymNode a, const char* file, int64_t line) {
+              return a->expect_size(file, line);
+            })
         .def(
             "has_hint",
             [](c10::SymNode a) {
5 changes: 5 additions & 0 deletions torch/csrc/utils/python_symnode.h
@@ -120,6 +120,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
     return getPyObj().attr("expect_true")(file, line).cast<bool>();
   }

+  bool expect_size(const char* file, int64_t line) override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("expect_size")(file, line).cast<bool>();
+  }
+
   int64_t int_() override {
     py::gil_scoped_acquire acquire;
     return getPyObj().attr("int_")().cast<int64_t>();
22 changes: 20 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
@@ -319,6 +319,7 @@ def guard_scalar(a):

 @record_shapeenv_event()
 def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int):
+    log.debug("_constrain_symbol_range %s [%s, %s] [%s, %s]", s, compiler_min, compiler_max, runtime_min, runtime_max)
     if r := shape_env.var_to_range.get(s, None):
         shape_env.var_to_range[s] = ValueRanges(
             builtins.max(r.lower, compiler_min), builtins.min(r.upper, compiler_max)
@@ -1019,6 +1020,18 @@ def expect_true(self, file, line):
         # deferred so you can't backtrace easily
         return self.shape_env.defer_runtime_assert(self.expr, f"{file}:{line}", fx_node=self.fx_node)

+    def expect_size(self, file, line):
+        b = self.ge(self.wrap_int(0))
+        # Generate a deferred runtime assert
+        r = b.expect_true(file, line)
+        # Refine the compile-time range, but only if it's unbacked.
+        # If you refine the range for hinted variables, you can end up
+        # making improper deductions, since compile-time reasoning may
+        # be incompatible with runtime reasoning.
+        if r and not self.has_hint():
+            _advise_is_size(SymInt(self))
+        return r
+
     def bool_(self):
         return self.guard_bool("", 0)
@@ -2839,8 +2852,6 @@ def create_symbol(
             # Even if we're duck shaping, if we haven't seen this particular
             # value before, we also create a new symbol
             sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
-            self.log.info("create_symbol %s = %s for %s", sympy_expr, val, source.name())
-            self.counter["create_symbol"] += 1
             # We always associate vars to vals
             self.var_to_val[sympy_expr] = sympy.Integer(val)
             # Do the appending later, because we always want to populate this
@@ -2871,7 +2882,14 @@ def create_symbol(
             if val not in vr:
                 raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

+            # Initialize the default runtime range to match the compile-time
+            # range, for backed SymInts (the two are allowed to diverge for
+            # unbacked)
+            self.runtime_var_to_range[sympy_expr] = vr
+
             r = sympy_expr
+
+            self.log.info("create_symbol %s = %s for %s [%s, %s]", sympy_expr, val, source.name(), vr.lower, vr.upper)
+            self.counter["create_symbol"] += 1
         else:
             # This implements duck-shaping: input sizes that match are assigned
             # the same symint
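
Taken together, expect_size defers the ">= 0" check to runtime and, for
unbacked values, advises the size-like range refinement. A hedged sketch
of the end-to-end behavior (exact failure mode may vary by version):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        return torch.empty(a.item())

    g = make_fx(f, tracing_mode="symbolic")(torch.tensor([3]))
    g(torch.tensor([4]))     # ok: the deferred assert 4 >= 0 holds
    # g(torch.tensor([-1]))  # fails at runtime (deferred assert and/or
    #                        # negative-size check), not at trace time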
