Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -645,16 +645,12 @@ def create_zip(*var: VariableBase):


# map
Dispatcher.register(
map,
(
"CallableVariable",
"VariableBase",
),
lambda fn, var: MapVariable.from_iterator(
fn, var, graph=var.graph, tracker=DummyTracker([var])
),
)
@Dispatcher.register_decorator(map)
def create_map(func: CallableVariable, *var: VariableBase):
tracked_vars = [func, *var]
return MapVariable.from_iterator(
func, var, graph=Dispatcher.graph, tracker=DummyTracker(tracked_vars)
)


# reversed
Expand Down
46 changes: 27 additions & 19 deletions python/paddle/jit/sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ..tracker import ConstTracker, DummyTracker
from .base import VariableFactory
from .basic import ConstantVariable
from .container import ContainerVariable, TupleVariable
from .container import TupleVariable

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -226,21 +226,25 @@ class MapVariable(SequenceIterVariable):
MapVariable holds a SequenceIterVariable and return a Iterable Variable after map function
"""

def __init__(self, func, val_iterator, graph, tracker):
super().__init__(val_iterator, graph, tracker)
def __init__(self, func, iters, graph, tracker):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, func, iters, graph, tracker):
def __init__(self, fn, iters: tuple[VariableBase, ...], graph, tracker):

# iters may contain only one iter.
super().__init__(iters, graph, tracker)
self.func = func

def next(self):
return self.func(self.hold.next())
values = []
for iter_var in self.hold:
next_var = iter_var.next()
values.append(next_var)
return self.func(*values)

def to_list(self) -> list:
retval = []
while True:
try:
retval.append(self.func(self.hold.next()))
except StopIteration:
break
return retval
lists = [iter_vars.to_list() for iter_vars in self.hold]
min_len = min(len(l) for l in lists)
result = []
for i in range(min_len):
result.append(self.func(*(l[i] for l in lists)))
return result

def has_side_effect(self) -> bool:
return self.hold.has_side_effect()
Expand All @@ -256,16 +260,20 @@ def _reconstruct(self, codegen: PyCodeGen):

@staticmethod
def from_iterator(
func, value, graph: FunctionGraph | None, tracker: Tracker
func,
value: Sequence[VariableBase],
graph: FunctionGraph | None,
tracker: Tracker,
):
iter_variable = (
value.get_iter() if isinstance(value, ContainerVariable) else value
)
map_targets = []

if isinstance(iter_variable, IterVariable):
return MapVariable(func, iter_variable, graph, tracker)
else:
return UserDefinedIterVariable(value, graph, tracker)
for variable in value:
iter_variable = variable.get_iter()
if not isinstance(iter_variable, SequenceIterVariable):
return UserDefinedIterVariable(value, graph, tracker)
map_targets.append(iter_variable)

return MapVariable(func, map_targets, graph, tracker)


# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph
Expand Down
13 changes: 13 additions & 0 deletions test/sot/test_builtin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from test_case_base import TestCaseBase, test_with_faster_guard

from paddle import to_tensor
from paddle.jit import sot
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import strict_mode_guard
Expand Down Expand Up @@ -98,6 +99,12 @@ def test_map_for_loop(x: list):
return res


@check_no_breakgraph
def test_map_multi_input(func, tensor_, tuple_):
x, y, z = map(func, tensor_, tuple_)
return x


class TestMap(TestCaseBase):
@test_with_faster_guard
def test_map(self):
Expand All @@ -122,6 +129,12 @@ def test_map_with_breakgraph(self):
@test_with_faster_guard
def test_map_unpack(self):
self.assert_results(test_map_unpack, [1, 2, 3, 4])
self.assert_results(
test_map_multi_input,
lambda x, y: x + y,
to_tensor([1, 2, 3]),
(2, 4, 6),
)

@test_with_faster_guard
def test_map_for_loop(self):
Expand Down