Skip to content

Commit

Permalink
Add exhaustive testing to ValueRanges, fix bugs (pytorch#94939)
Browse files Browse the repository at this point in the history
Since I didn't want to deal with nondeterministic tests, I went the exhaustive testing route for a fixed list of constants to look at. The tests generate random ranges, propagate the range through the function, and then pick elements in the range and check that the result on the operation is in the resulting range. This caught bugs in log, sqrt and pow.

My resolution for pow was a little special, because I had trouble figuring out the correct semantics under all inputs domains. Instead, I picked two input domains (pow on two point ranges, and pow where exponent is known) and only implemented those. Everything else we give up. I think this is unlikely to affect perf.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#94939
Approved by: https://github.com/lezcano, https://github.com/eellison, https://github.com/nunoplopes
  • Loading branch information
ezyang authored and pytorchmergebot committed Feb 17, 2023
1 parent 12c9a93 commit 08ef83f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 28 deletions.
91 changes: 82 additions & 9 deletions test/test_value_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@
2**32,
2**37 - 1,
]
# less constants for N^2 situations
LESS_CONSTANTS = [-1, 0, 1, 2, 100]


# The normal Python interpretation of the operators
# TODO: maybe make this work with sympy?
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
@staticmethod
def reciprocal(x):
Expand Down Expand Up @@ -119,15 +122,45 @@ def ceil(x):
return math.ceil(x)


def valid_unary(fn, v):
if fn == "log" and v <= 0:
return False
if fn == "reciprocal" and v == 0:
return False
if fn == "sqrt" and v < 0:
return False
return True


def valid_binary(fn, a, b):
if fn == "pow" and (
b > 4
or ( # sympy will expand to x*x*... for integral b; don't do it if it's big
a <= 0 and b == -1
)
or (a == b == 0) # no imaginary numbers # 0**0 is undefined
):
return False
if (fn == "div" or fn == "truediv") and b == 0:
return False
return True


def generate_range(vals):
for a1, a2 in itertools.product(vals, repeat=2):
if a1 > a2:
continue
# ranges that only admit infinite values are not interesting
if a1 == sympy.oo or a2 == -sympy.oo:
continue
yield ValueRanges(a1, a2)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
def test_unary_ref(self, fn):
for v in CONSTANTS:
if fn == "log" and v <= 0:
continue
if fn == "reciprocal" and v == 0:
continue
if fn == "sqrt" and v < 0:
if not valid_unary(fn, v):
continue
with self.subTest(v=v):
ref_r = getattr(ReferenceAnalysis, fn)(sympy.Integer(v))
Expand All @@ -138,9 +171,7 @@ def test_unary_ref(self, fn):
@parametrize("fn", BINARY_OPS)
def test_binary_ref(self, fn):
for a, b in itertools.product(CONSTANTS, repeat=2):
if fn == "pow" and (b > 4 or b == -1 or (a == b == 0)):
continue
if (fn == "div" or fn == "truediv") and b == 0:
if not valid_binary(fn, a, b):
continue
with self.subTest(a=a, b=b):
ref_r = getattr(ReferenceAnalysis, fn)(
Expand All @@ -153,6 +184,48 @@ def test_binary_ref(self, fn):
self.assertEqual(r.lower, r.upper)
self.assertEqual(ref_r, r.lower)

def test_mul_zero_unknown(self):
self.assertEqual(
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
ValueRanges.wrap(0),
)

@parametrize("fn", UNARY_OPS)
def test_unary_ref_range(self, fn):
vals = [-sympy.oo, *CONSTANTS, sympy.oo]
for a in generate_range(vals):
with self.subTest(a=a):
ref_r = getattr(ValueRangeAnalysis, fn)(a)
for a0 in CONSTANTS:
if a0 not in a:
continue
if not valid_unary(fn, a0):
continue
with self.subTest(a0=a0):
r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
self.assertIn(r, ref_r)

# This takes about 4s for all the variants
@parametrize("fn", BINARY_OPS)
def test_binary_ref_range(self, fn):
vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
for a, b in itertools.product(generate_range(vals), repeat=2):
# don't attempt pow on exponents that are too large (but oo is OK)
if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
continue
with self.subTest(a=a, b=b):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
if a0 not in a or b0 not in b:
continue
if not valid_binary(fn, a0, b0):
continue
with self.subTest(a0=a0, b0=b0):
r = getattr(ReferenceAnalysis, fn)(
sympy.Integer(a0), sympy.Integer(b0)
)
self.assertIn(r, ref_r)


instantiate_parametrized_tests(TestValueRanges)

Expand Down
39 changes: 21 additions & 18 deletions torch/utils/_sympy/value_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def simple_sympify(e):
elif isinstance(e, BooleanAtom):
return e
else:
raise AssertionError(f"not simple sympy type {type(e)}")
raise AssertionError(f"not simple sympy type {type(e)}: {e}")

# Sympy atomics only. Unlike <=, it also works on Sympy bools.
def sympy_generic_le(lower, upper):
Expand Down Expand Up @@ -268,31 +268,34 @@ def exp(x):

@staticmethod
def log(x):
return ValueRanges.increasing_map(
x, lambda y: -sympy.oo if y <= 0 else sympy.log(y)
)
if x.lower <= 0:
return ValueRanges.unknown()
return ValueRanges.increasing_map(x, sympy.log)

@staticmethod
def sqrt(x):
if x.lower < 0:
return ValueRanges.unknown()
return ValueRanges.increasing_map(x, sympy.sqrt)

@staticmethod
def pow(a, b):
def is_integer(val):
return (
isinstance(val, int)
or (isinstance(val, float) and val == int(val))
or (hasattr(val, "is_integer") and val.is_integer)
)

@classmethod
def pow(cls, a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if a.lower < 0 and not is_integer(b.lower):
# The function is not defined
return ValueRanges.unknown()
elif 0 in a and b.lower <= 0:
if a.lower == a.upper and b.lower == b.upper:
r = a.lower ** b.lower
if r == sympy.zoo:
return ValueRanges.unknown()
return ValueRanges.wrap(r)
elif b.lower == b.upper and b.lower >= 0:
i = ValueRanges.wrap(1)
for _ in range(b.lower):
i = cls.mul(i, a)
return i
else:
# This is fairly difficult to analyze, so give up for anything
# complicated
return ValueRanges.unknown()
return ValueRanges.coordinatewise_monotone_map(a, b, operator.pow)

@staticmethod
def minimum(a, b):
Expand Down

0 comments on commit 08ef83f

Please sign in to comment.