Skip to content

Commit

Permalink
[dynamo] simplify implementation for functools.reduce (pytorch#133778)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#133778
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#133712, pytorch#133769
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Aug 20, 2024
1 parent 178e856 commit 37b4bc6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion torch/_dynamo/polyfills/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Python polyfills for C functions.
"""

from . import builtins # noqa: F401
from . import builtins, functools # noqa: F401
43 changes: 43 additions & 0 deletions torch/_dynamo/polyfills/functools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Python polyfills for functools
"""

import functools
from typing import Callable, Iterable, TypeVar

from ..decorators import substitute_in_graph


_T = TypeVar("_T")
_U = TypeVar("_U")


class _INITIAL_MISSING:
pass


# Reference: https://docs.python.org/3/library/functools.html#functools.reduce
@substitute_in_graph(functools.reduce)
def reduce(
function: Callable[[_U, _T], _U],
iterable: Iterable[_T],
initial: _U = _INITIAL_MISSING, # type: ignore[assignment]
/,
) -> _U:
it = iter(iterable)

value: _U
if initial is _INITIAL_MISSING:
try:
value = next(it) # type: ignore[assignment]
except StopIteration:
raise TypeError(
"reduce() of empty iterable with no initial value",
) from None
else:
value = initial

for element in it:
value = function(value, element)

return value
1 change: 0 additions & 1 deletion torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,7 +2996,6 @@ def _builtin_function_ids() -> Dict[int, str]:
rv.update(
{
id(cast): "typing.cast",
id(functools.reduce): "functools.reduce",
id(copy.deepcopy): "copy.deepcopy",
}
)
Expand Down
18 changes: 4 additions & 14 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,10 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL):
if start is self._SENTINEL:
start = variables.ConstantVariable.create(0)
items = seq.unpack_var_sequence(tx)
return BuiltinVariable(functools.reduce).call_function(

from .builder import SourcelessBuilder

return SourcelessBuilder.create(tx, functools.reduce).call_function(
tx,
[
BuiltinVariable(operator.add),
Expand All @@ -1616,19 +1619,6 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL):
{},
)

def call_reduce(
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
):
if iterable.has_unpack_var_sequence(tx):
items = iterable.unpack_var_sequence(tx)
if initial is self._SENTINEL:
value, items = items[0], items[1:]
else:
value = initial
for element in items:
value = function.call_function(tx, [value, element], {})
return value

def call_getattr(
self,
tx: "InstructionTranslator",
Expand Down

0 comments on commit 37b4bc6

Please sign in to comment.