Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Rename self.hold to self.holds in iter.py #71473

Merged
merged 3 commits into from
Mar 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 38 additions & 39 deletions python/paddle/jit/sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ class IterVariable(VariableBase):
"""

def __init__(
self, holds: list[VariableBase], graph: FunctionGraph, tracker: Tracker
self, holded: list[VariableBase], graph: FunctionGraph, tracker: Tracker
):

super().__init__(graph, tracker)
self.hold = holds
self.holds = holded

def make_stringified_guard(self):
return [
result
for holds in self.hold
for result in holds.make_stringified_guard()
for holded in self.holds
for result in holded.make_stringified_guard()
]

def next(self):
Expand All @@ -60,11 +60,11 @@ def get_iter(self):
return self

def flatten_inner_vars(self) -> list[VariableBase]:
holds = self.hold
holded = self.holds
return [
inner_var
for hold in holds
for inner_var in hold.flatten_inner_vars()
for obj in holded
for inner_var in obj.flatten_inner_vars()
]


Expand All @@ -84,21 +84,20 @@ class SequenceIterVariable(IterVariable):

def __init__(
self,
holds: VariableBase | list[VariableBase],
holded: VariableBase | list[VariableBase],
graph: FunctionGraph,
tracker: Tracker,
):
if not isinstance(holds, list):
holds = [holds]
super().__init__(holds, graph, tracker)
if not isinstance(holded, list):
holded = [holded]
super().__init__(holded, graph, tracker)
self.idx = 0
self.graph.side_effects.record_mutable_variable(self)

def next(self):
holds = self.hold[0]
# TODO: self.hold should have a __len__ method
if self.idx < len(holds):
val = holds[self.idx]
holded = self.holds[0]
if self.idx < len(holded):
val = holded[self.idx]
self.idx += 1
return val
else:
Expand All @@ -107,11 +106,11 @@ def next(self):
def to_list(self) -> list:
if self.has_side_effect():
raise FallbackError("Can not convert an used iterator into list")
holds = self.hold[0]
self.idx = len(holds)
holded = self.holds[0]
self.idx = len(holded)
retval = []
for i in range(len(holds)):
retval.append(holds[i])
for i in range(len(holded)):
retval.append(holded[i])
return retval

def has_side_effect(self) -> bool:
Expand All @@ -121,7 +120,7 @@ def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
self.hold[0].reconstruct(codegen)
self.holds[0].reconstruct(codegen)
codegen.gen_get_iter()

@property
Expand All @@ -142,30 +141,30 @@ def __init__(
super().__init__(val_iterator, graph, tracker)

def next(self):
val = self.hold[0].next()
val = self.holds[0].next()
idx_var = ConstantVariable(self.idx, self.graph, ConstTracker(self.idx))
self.idx += 1
return TupleVariable(
(idx_var, val), self.graph, DummyTracker([idx_var, val])
)

def to_list(self):
values = self.hold[0].to_list()
values = self.holds[0].to_list()
idx = [
ConstantVariable(i, self.graph, ConstTracker(i))
for i in range(len(values))
]
return list(zip(idx, values))

def has_side_effect(self) -> bool:
return self.hold[0].has_side_effect()
return self.holds[0].has_side_effect()

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("enumerate", push_null=True)
self.hold[0].reconstruct(codegen)
self.holds[0].reconstruct(codegen)
codegen.gen_call_function(1)

@staticmethod
Expand All @@ -191,7 +190,7 @@ def next(self):
# can not use <listcomp> here, because it will raise a RuntimeError("StopIteration")
# but we want a StopIteration Exception
values = []
for iter_var in self.hold:
for iter_var in self.holds:
next_var = iter_var.next()
values.append(next_var)

Expand All @@ -200,30 +199,30 @@ def next(self):
)

def to_list(self):
lists = [iter_vars.to_list() for iter_vars in self.hold]
lists = [iter_vars.to_list() for iter_vars in self.holds]
min_len = min(len(l) for l in lists)
result = []
for i in range(min_len):
result.append(
VariableFactory.from_value(
tuple(l[i] for l in lists),
self.graph,
DummyTracker(list(self.hold)),
DummyTracker(list(self.holds)),
)
)
return result

def has_side_effect(self) -> bool:
return any(iter_var.has_side_effect() for iter_var in self.hold)
return any(iter_var.has_side_effect() for iter_var in self.holds)

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("zip", push_null=True)
for iter_var in self.hold:
for iter_var in self.holds:
iter_var.reconstruct(codegen)
codegen.gen_call_function(len(self.hold))
codegen.gen_call_function(len(self.holds))

@staticmethod
def from_iterator(
Expand Down Expand Up @@ -255,28 +254,28 @@ def __init__(self, fn, iters: list[IterVariable], graph, tracker):

def next(self):

return self.fn(*[iter_var.next() for iter_var in self.hold])
return self.fn(*[iter_var.next() for iter_var in self.holds])

def to_list(self) -> list:
lists = [iter_var.to_list() for iter_var in self.hold]
lists = [iter_var.to_list() for iter_var in self.holds]
min_len = min(len(l) for l in lists)
result = []
for i in range(min_len):
result.append(self.fn(*(l[i] for l in lists)))
return result

def has_side_effect(self) -> bool:
return any(iter_var.has_side_effect() for iter_var in self.hold)
return any(iter_var.has_side_effect() for iter_var in self.holds)

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("map", push_null=True)
self.fn.reconstruct(codegen)
for iter_var in self.hold:
for iter_var in self.holds:
iter_var.reconstruct(codegen)
codegen.gen_call_function(len(self.hold) + 1)
codegen.gen_call_function(len(self.holds) + 1)

@staticmethod
def from_iterator(
Expand All @@ -300,13 +299,13 @@ def from_iterator(
class UserDefinedIterVariable(IterVariable):
def __init__(
self,
holds: VariableBase | list[VariableBase],
holded: VariableBase | list[VariableBase],
graph: FunctionGraph,
tracker: Tracker,
):
if not isinstance(holds, list):
holds = [holds]
super().__init__(holds, graph, tracker)
if not isinstance(holded, list):
holded = [holded]
super().__init__(holded, graph, tracker)

def next(self):
raise BreakGraphError(
Expand Down
Loading