Skip to content

Commit

Permalink
Tensor get_dependencies: prepare for it to be moved
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Feb 28, 2023
1 parent cb3b1a8 commit cc3cdf2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 16 deletions.
18 changes: 10 additions & 8 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def __init__(
# We can still try to look at dependencies and use those batch info.
batches = []
for dep in self.get_dependencies(_extra_layer_dict=layer_dict):
if dep.data.batch and dep.data.batch not in batches:
batches.append(dep.data.batch)
if dep.tensor is not None and dep.tensor.data.batch and dep.tensor.data.batch not in batches:
batches.append(dep.tensor.data.batch)
if batches:
data.batch = nn.BatchInfo.get_common_batch_info(batches)
elif name_ctx.root.global_batch:
Expand Down Expand Up @@ -417,22 +417,24 @@ def mark_as_default_output(self) -> Tensor:
res.mark_as_output()
return res

def get_dependencies(self, *, _extra_layer_dict=None) -> List[nn.Tensor]:
def get_dependencies(self, *, _extra_layer_dict=None) -> List[nn.NameCtx]:
"""
:return: list of tensors this tensor depends on
"""
dep_list = []
dep_name_set = set()

def _maybe_add_dep(x):
if isinstance(x, nn.Tensor):
if x.raw_tensor in dep_name_set:
if isinstance(x, nn.NameCtx):
if x in dep_name_set:
return
dep_list.append(x)
dep_name_set.add(x.raw_tensor)
dep_name_set.add(x)
return
if isinstance(x, nn.Tensor):
return _maybe_add_dep(x.raw_tensor)
if isinstance(x, nn.Net):
_maybe_add_dep(x.name_ctx.children["output"].tensor)
return _maybe_add_dep(x.name_ctx.children["output"].tensor)

if _extra_layer_dict:
nest.map_structure(_maybe_add_dep, _extra_layer_dict)
Expand Down Expand Up @@ -691,7 +693,7 @@ def initial(self, value: Optional[nn.init.ParamInitType]):
# However, it's not clear whether this is always safe.
for dep in value.get_dependencies():
assert (
dep.raw_tensor.parent.can_access_children_from_root
dep.parent.can_access_children_from_root
), f"dep {dep} of moved value {value} is not accessible"
self.raw_tensor.layer_dict["init_by_layer"] = value
else:
Expand Down
4 changes: 3 additions & 1 deletion nn/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,7 @@ def __call__(self):
results.append(res)
else:
results.append(_get_sub_layer(res, name, data=true_v.data.copy_template()))
results[-1].raw_tensor.layer_extra_dependencies.extend((self.cond.condition, true_v, false_v))
results[-1].raw_tensor.layer_extra_dependencies.extend(
(self.cond.condition.raw_tensor, true_v.raw_tensor, false_v.raw_tensor)
)
return nest.pack_sequence_as(true_value, results)
8 changes: 4 additions & 4 deletions nn/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if not exc_type:
res = self.layer_module() # create the rec layer itself
if self.end_ref is not None:
res.raw_tensor.layer_extra_dependencies.append(self.end_ref)
res.raw_tensor.layer_extra_dependencies.append(self.end_ref.raw_tensor)

@property
def has_entered_scope(self) -> bool:
Expand Down Expand Up @@ -217,7 +217,7 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor:
res.raw_tensor.tensor_remove_unused_cleanup_hooks.append(
lambda _: source.raw_tensor.layer_dict.pop("need_last")
)
res.raw_tensor.layer_extra_dependencies.append(source)
res.raw_tensor.layer_extra_dependencies.append(source.raw_tensor)
self._last_frames[source.raw_tensor] = res
return res

Expand Down Expand Up @@ -327,10 +327,10 @@ def __init__(self, *, name_ctx: nn.NameCtx, cur_layer_name_ctx: nn.NameCtx, data
super().__init__(name_ctx=name_ctx, data=data, is_ref=True)
self.cur_layer_name_ctx = cur_layer_name_ctx

def get_dependencies(self, **kwargs) -> List[nn.Tensor]:
def get_dependencies(self, **kwargs) -> List[nn.NameCtx]:
"""dependencies"""
# Need to overwrite this because self.cur_layer_name_ctx.tensor is only available later.
return super(PrevTensorRef, self).get_dependencies(**kwargs) + [self.cur_layer_name_ctx.tensor]
return super(PrevTensorRef, self).get_dependencies(**kwargs) + [self.cur_layer_name_ctx]

def assign_new_cur_tensor_name_ctx(self, cur_tensor_name_ctx: nn.NameCtx):
"""
Expand Down
6 changes: 3 additions & 3 deletions nn/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self.tensor_parent_modules = [] # type: List[Tuple[nn.Module, str]] # via parent module attrib
self.tensor_remove_unused_cleanup_hooks = [] # type: List[Callable[[nn.Tensor], None]]
self.layer_dict = None # type: Optional[nn.LayerDictRaw]
self.layer_extra_dependencies = [] # type: List[nn.Tensor]
self.layer_extra_dependencies = [] # type: List[nn.NameCtx]
self.debug_layer = None # type: Optional[nn.LayerBase]
self._enter_stack_frames = None # type: Optional[Set[types.FrameType]]
self.is_subnet = False # it says whether it can have children
Expand Down Expand Up @@ -374,8 +374,8 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self):
used_names.add(tensor.raw_tensor)
src_ = src + [tensor]
for dep in tensor.get_dependencies():
if dep.raw_tensor not in used_names:
queue.append((dep, src_))
if dep.tensor is not None and dep not in used_names:
queue.append((dep.tensor, src_))

# Parameters usually have no parent assigned at creation time.
if not tensor.raw_tensor.parent and tensor.raw_tensor != root:
Expand Down

0 comments on commit cc3cdf2

Please sign in to comment.