move Tensor get_dependencies to NameCtx
albertz committed Feb 28, 2023
1 parent 30a2432 commit 0ca388f
Showing 2 changed files with 34 additions and 34 deletions.
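In short, the dependency-collection logic moves from nn.Tensor.get_dependencies to nn.NameCtx.get_tensor_dependencies, and call sites switch from the tensor to its underlying name ctx via raw_tensor. A minimal before/after sketch of the call-site migration (the tensor variable here is illustrative, not from the commit):

    # Before this commit: dependencies were collected on the Tensor itself.
    deps = tensor.get_dependencies()

    # After this commit: the same logic lives on the tensor's NameCtx.
    deps = tensor.raw_tensor.get_tensor_dependencies()
    for dep in deps:  # each dep is an nn.NameCtx
        if dep.tensor is not None:
            pass  # inspect the dependency's tensor, as the call sites below do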
35 changes: 2 additions & 33 deletions nn/base.py
@@ -141,7 +141,7 @@ def __init__(
# but we don't do that here.
# We can still try to look at dependencies and use those batch info.
batches = []
for dep in self.get_dependencies(_extra_layer_dict=layer_dict):
for dep in self.raw_tensor.get_tensor_dependencies(_extra_layer_dict=layer_dict):
if dep.tensor is not None and dep.tensor.data.batch and dep.tensor.data.batch not in batches:
batches.append(dep.tensor.data.batch)
if batches:
@@ -417,37 +417,6 @@ def mark_as_default_output(self) -> Tensor:
res.mark_as_output()
return res

def get_dependencies(self, *, _extra_layer_dict=None) -> List[nn.NameCtx]:
"""
:return: list of tensors this tensor depends on
"""
dep_list = []
dep_name_set = set()

def _maybe_add_dep(x):
if isinstance(x, nn.NameCtx):
if x in dep_name_set:
return
dep_list.append(x)
dep_name_set.add(x)
return
if isinstance(x, nn.Tensor):
return _maybe_add_dep(x.raw_tensor)
if isinstance(x, nn.Net):
return _maybe_add_dep(x.name_ctx.children["output"].tensor)

if _extra_layer_dict:
nest.map_structure(_maybe_add_dep, _extra_layer_dict)
if self.raw_tensor.layer_dict:
nest.map_structure(_maybe_add_dep, self.raw_tensor.layer_dict)
if self.raw_tensor.children and "output" in self.raw_tensor.children:
_maybe_add_dep(self.raw_tensor.children["output"].tensor)
if self.raw_tensor.parent and self.raw_tensor.parent.tensor:
_maybe_add_dep(self.raw_tensor.parent.tensor)
if self.raw_tensor.layer_extra_dependencies:
dep_list.extend(self.raw_tensor.layer_extra_dependencies)
return dep_list

def _replace_by(self, tensor: nn.Tensor):
"""
Replace this tensor by the given tensor.
@@ -691,7 +660,7 @@ def initial(self, value: Optional[nn.init.ParamInitType]):
value.raw_tensor.assign_parent(accessible_parent)
# We could also maybe move out all the dependencies.
# However, it's not clear whether this is always safe.
for dep in value.get_dependencies():
for dep in value.raw_tensor.get_tensor_dependencies():
assert (
dep.parent.can_access_children_from_root
), f"dep {dep} of moved value {value} is not accessible"
33 changes: 32 additions & 1 deletion nn/naming.py
@@ -373,7 +373,7 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self):
continue
used_names.add(tensor.raw_tensor)
src_ = src + [tensor]
for dep in tensor.get_dependencies():
for dep in tensor.raw_tensor.get_tensor_dependencies():
if dep.tensor is not None and dep not in used_names:
queue.append((dep.tensor, src_))

@@ -710,6 +710,37 @@ def _get_unique_name(self, suggested_name: Optional[str] = None) -> str:
return name_
i += 1

def get_tensor_dependencies(self, *, _extra_layer_dict=None) -> List[nn.NameCtx]:
"""
:return: list of tensors this tensor depends on
"""
dep_list = []
dep_name_set = set()

def _maybe_add_dep(x):
if isinstance(x, nn.NameCtx):
if x in dep_name_set:
return
dep_list.append(x)
dep_name_set.add(x)
return
if isinstance(x, nn.Tensor):
return _maybe_add_dep(x.raw_tensor)
if isinstance(x, nn.Net):
return _maybe_add_dep(x.name_ctx.children["output"].tensor)

if _extra_layer_dict:
nest.map_structure(_maybe_add_dep, _extra_layer_dict)
if self.layer_dict:
nest.map_structure(_maybe_add_dep, self.layer_dict)
if self.children and "output" in self.children:
_maybe_add_dep(self.children["output"].tensor)
if self.parent and self.parent.tensor:
_maybe_add_dep(self.parent.tensor)
if self.layer_extra_dependencies:
dep_list.extend(self.layer_extra_dependencies)
return dep_list

def make_all_sub_networks_and_optimize(self):
"""
Go up all parents and create subnetworks which are not initialized yet.
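The hunk in _remove_unused_and_assign_parents_and_handle_subnets above keeps its breadth-first walk over the graph and only routes it through the new NameCtx method. A rough sketch of that traversal pattern, with the queue initialization and visited check assumed from context rather than shown in the diff:

    # Hypothetical reconstruction of the surrounding traversal loop, for illustration only.
    queue = [(root_tensor, [])]  # pairs of (tensor, chain of tensors leading to it)
    used_names = set()
    while queue:
        tensor, src = queue.pop(0)
        if tensor.raw_tensor in used_names:
            continue
        used_names.add(tensor.raw_tensor)
        src_ = src + [tensor]
        for dep in tensor.raw_tensor.get_tensor_dependencies():
            if dep.tensor is not None and dep not in used_names:
                queue.append((dep.tensor, src_))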
