Skip to content

Commit

Permalink
[dynamo] refactor builtins.enumerate to use polyfill (pytorch#133894)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#133894
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#133864
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Aug 31, 2024
1 parent ebbdeee commit a854c3a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
20 changes: 19 additions & 1 deletion torch/_dynamo/polyfills/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
Python polyfills for builtins
"""

from __future__ import annotations

import builtins
from typing import Iterable
from typing import Iterable, TypeVar

from ..decorators import substitute_in_graph


__all__ = [
"all",
"any",
"enumerate",
]


_T = TypeVar("_T")


@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
for elem in iterable:
Expand All @@ -28,3 +34,15 @@ def any(iterable: Iterable[object], /) -> bool:
if elem:
return True
return False


@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
if not isinstance(start, int):
raise TypeError(
f"{type(start).__name__!r} object cannot be interpreted as an integer"
)

for x in iterable:
yield start, x
start += 1
16 changes: 0 additions & 16 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,22 +1442,6 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
return variables.TupleVariable(items)

def call_enumerate(self, tx: "InstructionTranslator", *args):
if len(args) == 1:
start = 0
else:
assert len(args) == 2
assert isinstance(args[1], variables.ConstantVariable)
start = args[1].as_python_constant()
if args[0].has_unpack_var_sequence(tx):
items = [
variables.TupleVariable(
[variables.ConstantVariable.create(idx), var],
)
for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
]
return variables.TupleVariable(items)

def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs)

Expand Down

0 comments on commit a854c3a

Please sign in to comment.