[dynamo] Avoid recompilation when the PyTorch function accepts scalars (pytorch#108162)

Before, we would create a 0D tensor from the Python scalar input, which would
incur a guard and a specialisation on its value.

It's not clear whether the guard and specialisation are the right behaviour
when we create 0D tensors, but that's a story for another day.

Pull Request resolved: pytorch#108162
Approved by: https://github.com/ev-br, https://github.com/peterbell10
lezcano authored and pytorchmergebot committed Sep 1, 2023
1 parent 591cb77 commit 2a6ef9b
Showing 6 changed files with 62 additions and 18 deletions.
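For context, the recompilation being avoided can be reproduced with a few lines that mirror the new test below (a minimal sketch, assuming a PyTorch build with torch._dynamo and its numpy tracing support; the variable names are illustrative):

import numpy as np
import torch
import torch._dynamo.testing

def fn(x, a):
    # `a` is a plain Python scalar. Previously it was wrapped into a 0D tensor
    # whose value was guarded on and specialised, so a different scalar recompiled.
    return np.where(x < 0.5, a, x)

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)

x = np.random.randn(8)
opt_fn(x, 3)
opt_fn(x, 4)
print(cnts.frame_count)  # 1 with this change; a recompilation would push it to 2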
18 changes: 18 additions & 0 deletions test/dynamo/test_misc.py
@@ -1366,6 +1366,24 @@ def fn(x, y):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)

+    def test_numpy_recompilation_scalar(self):
+        def fn(x, a):
+            return np.where(x < 0.5, a, x)
+
+        x = np.random.randn(8)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
+
+        ref = fn(x, 3)
+        res = opt_fn(x, 3)
+        self.assertEqual(ref, res)
+
+        ref = fn(x, 4)
+        res = opt_fn(x, 4)
+        self.assertEqual(ref, res)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     def test_tensor_interacts_with_numpy_ndarray(self):
         def fn(x, y):
             a = x.numpy()
23 changes: 22 additions & 1 deletion torch/_numpy/_dtypes_impl.py
@@ -80,31 +80,52 @@ def python_type_for_torch(dtyp):

 # ### NEP 50 helpers ###

-SCALAR_TYPES = {int, bool, float, complex}
+_SCALAR_TYPES = (int, bool, float, complex)
+
+_SCALAR_AND_SYMBOLIC_TYPES = (
+    *_SCALAR_TYPES,
+    torch.SymInt,
+    torch.SymFloat,
+    torch.SymBool,
+)
+
+
+def is_scalar(x):
+    return isinstance(x, _SCALAR_TYPES)
+
+
+def is_scalar_or_symbolic(x):
+    return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES)


 def _dtype_for_scalar(py_type):
     return {
         bool: torch.bool,
+        torch.SymBool: torch.bool,
         int: torch.int64,
+        torch.SymInt: torch.int64,
         float: torch.float64,
+        torch.SymFloat: torch.float64,
         complex: torch.complex128,
     }[py_type]


 def _category(dtype):
     return {
         torch.bool: 0,
+        torch.SymBool: 0,
         # int
         torch.uint8: 1,
         torch.int8: 1,
         torch.int16: 1,
         torch.int32: 1,
         torch.int64: 1,
+        torch.SymInt: 1,
         # float
         torch.float16: 2,
         torch.float32: 2,
         torch.float64: 2,
+        torch.SymFloat: 2,
         # complex
         torch.complex64: 3,
         torch.complex128: 3,
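A quick illustration of how the new helpers behave (a usage sketch against the private torch._numpy._dtypes_impl module, not part of the diff):

import torch
from torch._numpy import _dtypes_impl

# Plain Python scalars are recognised...
assert _dtypes_impl.is_scalar(3)
assert _dtypes_impl.is_scalar(2.5)
assert not _dtypes_impl.is_scalar([3])

# ...and map onto the NEP 50 default dtypes.
assert _dtypes_impl._dtype_for_scalar(bool) is torch.bool
assert _dtypes_impl._dtype_for_scalar(int) is torch.int64
assert _dtypes_impl._dtype_for_scalar(float) is torch.float64

# is_scalar_or_symbolic additionally accepts torch.SymInt / SymFloat / SymBool,
# the symbolic scalars produced while tracing, which lets a traced scalar
# argument flow through promotion without being materialised as a 0D tensor.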
14 changes: 6 additions & 8 deletions torch/_numpy/_funcs_impl.py
@@ -17,6 +17,7 @@
 from . import _dtypes_impl, _util
 from ._normalizations import (
     ArrayLike,
+    ArrayLikeOrScalar,
     CastingModes,
     DTypeLike,
     NDArray,
@@ -626,8 +627,8 @@ def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):

 def where(
     condition: ArrayLike,
-    x: Optional[ArrayLike] = None,
-    y: Optional[ArrayLike] = None,
+    x: Optional[ArrayLikeOrScalar] = None,
+    y: Optional[ArrayLikeOrScalar] = None,
     /,
 ):
     if (x is None) != (y is None):
@@ -984,8 +985,7 @@ def clip(
     return torch.clamp(a, min, max)


-def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
-    # XXX: scalar repeats; ArrayLikeOrScalar ?
+def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
     return torch.repeat_interleave(a, repeats, axis)

@@ -1553,9 +1553,7 @@ def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
     if n == 0:
         # no spacing argument - use 1 in all axes
         dx = [1.0] * len_axes
-    elif n == 1 and (
-        type(varargs[0]) in _dtypes_impl.SCALAR_TYPES or varargs[0].ndim == 0
-    ):
+    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
         # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
         dx = varargs * len_axes
     elif n == len_axes:
@@ -1616,7 +1614,7 @@ def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
         out = torch.empty_like(f, dtype=otype)

         # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
-        uniform_spacing = type(ax_dx) in _dtypes_impl.SCALAR_TYPES or ax_dx.ndim == 0
+        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

         # Numerical differentiation: 2nd order interior
         slice1[axis] = slice(1, -1)
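With these annotations, the affected functions accept plain Python scalars directly, for example (a usage sketch against the private torch._numpy namespace):

import torch._numpy as tnp

x = tnp.asarray([0.1, 0.7, 0.3])

# The scalar 3 is no longer wrapped into a 0D array on the way in.
tnp.where(x < 0.5, 3, x)             # -> [3.0, 0.7, 3.0]

# repeat now takes a scalar `repeats` directly.
tnp.repeat(tnp.asarray([1, 2]), 2)   # -> [1, 1, 2, 2]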
2 changes: 1 addition & 1 deletion torch/_numpy/_ndarray.py
@@ -452,7 +452,7 @@ def __setitem__(self, index, value):
         index = _util.ndarrays_to_tensors(index)
         index = _upcast_int_indices(index)

-        if type(value) not in _dtypes_impl.SCALAR_TYPES:
+        if not _dtypes_impl.is_scalar(value):
             value = normalize_array_like(value)
             value = _util.cast_if_needed(value, self.tensor.dtype)

11 changes: 9 additions & 2 deletions torch/_numpy/_normalizations.py
@@ -47,11 +47,17 @@ def normalize_array_like(x, parm=None):


 def normalize_array_like_or_scalar(x, parm=None):
-    if type(x) in _dtypes_impl.SCALAR_TYPES:
+    if _dtypes_impl.is_scalar_or_symbolic(x):
         return x
     return normalize_array_like(x, parm)


+def normalize_optional_array_like_or_scalar(x, parm=None):
+    if x is None:
+        return None
+    return normalize_array_like_or_scalar(x, parm)
+
+
 def normalize_optional_array_like(x, parm=None):
     # This explicit normalizer is needed because otherwise normalize_array_like
     # does not run for a parameter annotated as Optional[ArrayLike]
@@ -118,9 +124,10 @@ def normalize_casting(arg, parm=None):

 normalizers = {
     "ArrayLike": normalize_array_like,
-    "Union[ArrayLike, Scalar]": normalize_array_like_or_scalar,
+    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
     "Optional[ArrayLike]": normalize_optional_array_like,
     "Sequence[ArrayLike]": normalize_seq_array_like,
+    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
     "Optional[NDArray]": normalize_ndarray,
     "Optional[OutArray]": normalize_outarray,
     "NDArray": normalize_ndarray,
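The `normalizers` table is keyed by the string form of a parameter's annotation; the `@normalizer` decorator looks up each parameter's annotation here and runs the matching normalizer on the incoming argument before the wrapped function sees it. A toy sketch of that annotation-driven dispatch (illustrative only, not the decorator actually defined in this file):

import inspect

# Toy stand-ins for the normalizers table shown above.
def _upper(x, parm=None):
    return x.upper()

toy_normalizers = {"UpperStr": _upper}

def toy_normalizer(func):
    sig = inspect.signature(func)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name in bound.arguments:
            parm = sig.parameters[name]
            norm = toy_normalizers.get(str(parm.annotation))
            if norm is not None:
                bound.arguments[name] = norm(bound.arguments[name], parm)
        return func(*bound.args, **bound.kwargs)

    return wrapper

@toy_normalizer
def greet(who: "UpperStr"):
    return f"hello {who}"

print(greet("world"))  # hello WORLD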
12 changes: 6 additions & 6 deletions torch/_numpy/_ufuncs.py
@@ -1,18 +1,18 @@
 from __future__ import annotations

-from typing import Optional, Union
+from typing import Optional

 import torch

 from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    ArrayLikeOrScalar,
     CastingModes,
     DTypeLike,
     normalizer,
     NotImplementedType,
     OutArray,
-    Scalar,
 )


@@ -71,8 +71,8 @@ def deco_binary_ufunc(torch_func):

     @normalizer
     def wrapped(
-        x1: Union[ArrayLike, Scalar],
-        x2: Union[ArrayLike, Scalar],
+        x1: ArrayLikeOrScalar,
+        x2: ArrayLikeOrScalar,
         /,
         out: Optional[OutArray] = None,
         *,
@@ -145,8 +145,8 @@ def matmul(
 # ldexp casting is special : the dtype of the result == dtype of the 1st arg
 @normalizer
 def ldexp(
-    x1: Union[ArrayLike, Scalar],
-    x2: Union[ArrayLike, Scalar],
+    x1: ArrayLikeOrScalar,
+    x2: ArrayLikeOrScalar,
     /,
     out: Optional[OutArray] = None,
     *,
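With both operands annotated as ArrayLikeOrScalar, binary ufuncs now take a Python scalar on either side without first materialising a 0D array (a usage sketch against the private torch._numpy namespace):

import torch._numpy as tnp

a = tnp.asarray([1.0, 2.0, 3.0])

tnp.add(a, 1)        # -> [2.0, 3.0, 4.0]
tnp.multiply(2, a)   # -> [2.0, 4.0, 6.0]

# ldexp computes x1 * 2**x2; per the comment above, the result keeps x1's dtype.
tnp.ldexp(a, 3)      # -> [8.0, 16.0, 24.0]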
