Skip to content

Commit

Permalink
[SOT] Rename self.hold to self.holds in iter.py (#71473)
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenStain authored Mar 7, 2025
1 parent fecaa1c commit 85a81d9
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions python/paddle/jit/sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ class IterVariable(VariableBase):
"""

def __init__(
self, holds: list[VariableBase], graph: FunctionGraph, tracker: Tracker
self, holded: list[VariableBase], graph: FunctionGraph, tracker: Tracker
):

super().__init__(graph, tracker)
self.hold = holds
self.holds = holded

def make_stringified_guard(self):
return [
result
for holds in self.hold
for result in holds.make_stringified_guard()
for holded in self.holds
for result in holded.make_stringified_guard()
]

def next(self):
Expand All @@ -60,11 +60,11 @@ def get_iter(self):
return self

def flatten_inner_vars(self) -> list[VariableBase]:
holds = self.hold
holded = self.holds
return [
inner_var
for hold in holds
for inner_var in hold.flatten_inner_vars()
for obj in holded
for inner_var in obj.flatten_inner_vars()
]


Expand All @@ -84,21 +84,20 @@ class SequenceIterVariable(IterVariable):

def __init__(
self,
holds: VariableBase | list[VariableBase],
holded: VariableBase | list[VariableBase],
graph: FunctionGraph,
tracker: Tracker,
):
if not isinstance(holds, list):
holds = [holds]
super().__init__(holds, graph, tracker)
if not isinstance(holded, list):
holded = [holded]
super().__init__(holded, graph, tracker)
self.idx = 0
self.graph.side_effects.record_mutable_variable(self)

def next(self):
holds = self.hold[0]
# TODO: self.hold should have a __len__ method
if self.idx < len(holds):
val = holds[self.idx]
holded = self.holds[0]
if self.idx < len(holded):
val = holded[self.idx]
self.idx += 1
return val
else:
Expand All @@ -107,11 +106,11 @@ def next(self):
def to_list(self) -> list:
if self.has_side_effect():
raise FallbackError("Can not convert an used iterator into list")
holds = self.hold[0]
self.idx = len(holds)
holded = self.holds[0]
self.idx = len(holded)
retval = []
for i in range(len(holds)):
retval.append(holds[i])
for i in range(len(holded)):
retval.append(holded[i])
return retval

def has_side_effect(self) -> bool:
Expand All @@ -121,7 +120,7 @@ def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
self.hold[0].reconstruct(codegen)
self.holds[0].reconstruct(codegen)
codegen.gen_get_iter()

@property
Expand All @@ -142,30 +141,30 @@ def __init__(
super().__init__(val_iterator, graph, tracker)

def next(self):
val = self.hold[0].next()
val = self.holds[0].next()
idx_var = ConstantVariable(self.idx, self.graph, ConstTracker(self.idx))
self.idx += 1
return TupleVariable(
(idx_var, val), self.graph, DummyTracker([idx_var, val])
)

def to_list(self):
values = self.hold[0].to_list()
values = self.holds[0].to_list()
idx = [
ConstantVariable(i, self.graph, ConstTracker(i))
for i in range(len(values))
]
return list(zip(idx, values))

def has_side_effect(self) -> bool:
return self.hold[0].has_side_effect()
return self.holds[0].has_side_effect()

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("enumerate", push_null=True)
self.hold[0].reconstruct(codegen)
self.holds[0].reconstruct(codegen)
codegen.gen_call_function(1)

@staticmethod
Expand All @@ -191,7 +190,7 @@ def next(self):
# can not use <listcomp> here, because it will raise a RuntimeError("StopIteration")
# but we want a StopIteration Exception
values = []
for iter_var in self.hold:
for iter_var in self.holds:
next_var = iter_var.next()
values.append(next_var)

Expand All @@ -200,30 +199,30 @@ def next(self):
)

def to_list(self):
lists = [iter_vars.to_list() for iter_vars in self.hold]
lists = [iter_vars.to_list() for iter_vars in self.holds]
min_len = min(len(l) for l in lists)
result = []
for i in range(min_len):
result.append(
VariableFactory.from_value(
tuple(l[i] for l in lists),
self.graph,
DummyTracker(list(self.hold)),
DummyTracker(list(self.holds)),
)
)
return result

def has_side_effect(self) -> bool:
return any(iter_var.has_side_effect() for iter_var in self.hold)
return any(iter_var.has_side_effect() for iter_var in self.holds)

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("zip", push_null=True)
for iter_var in self.hold:
for iter_var in self.holds:
iter_var.reconstruct(codegen)
codegen.gen_call_function(len(self.hold))
codegen.gen_call_function(len(self.holds))

@staticmethod
def from_iterator(
Expand Down Expand Up @@ -255,28 +254,28 @@ def __init__(self, fn, iters: list[IterVariable], graph, tracker):

def next(self):

return self.fn(*[iter_var.next() for iter_var in self.hold])
return self.fn(*[iter_var.next() for iter_var in self.holds])

def to_list(self) -> list:
lists = [iter_var.to_list() for iter_var in self.hold]
lists = [iter_var.to_list() for iter_var in self.holds]
min_len = min(len(l) for l in lists)
result = []
for i in range(min_len):
result.append(self.fn(*(l[i] for l in lists)))
return result

def has_side_effect(self) -> bool:
return any(iter_var.has_side_effect() for iter_var in self.hold)
return any(iter_var.has_side_effect() for iter_var in self.holds)

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("map", push_null=True)
self.fn.reconstruct(codegen)
for iter_var in self.hold:
for iter_var in self.holds:
iter_var.reconstruct(codegen)
codegen.gen_call_function(len(self.hold) + 1)
codegen.gen_call_function(len(self.holds) + 1)

@staticmethod
def from_iterator(
Expand All @@ -300,13 +299,13 @@ def from_iterator(
class UserDefinedIterVariable(IterVariable):
def __init__(
self,
holds: VariableBase | list[VariableBase],
holded: VariableBase | list[VariableBase],
graph: FunctionGraph,
tracker: Tracker,
):
if not isinstance(holds, list):
holds = [holds]
super().__init__(holds, graph, tracker)
if not isinstance(holded, list):
holded = [holded]
super().__init__(holded, graph, tracker)

def next(self):
raise BreakGraphError(
Expand Down

0 comments on commit 85a81d9

Please sign in to comment.