Skip to content

Commit

Permalink
move Tensor require_global_access to NameCtx
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 16, 2023
1 parent 98b55e8 commit 2064a32
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
5 changes: 1 addition & 4 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ class Tensor:
or via :func:`get_extern_data` for external data.
"""

require_global_access = False

def __init__(
self,
*,
Expand Down Expand Up @@ -505,8 +503,6 @@ class Parameter(Tensor):
wrapping to ``VariableLayer`` in RETURNN.
"""

require_global_access = True

def __init__(
self,
shape: Sequence[Dim],
Expand Down Expand Up @@ -538,6 +534,7 @@ def __init__(
# The name_ctx object will be completed by this information later.
# See Tensor.get_name_in_ctx().
name_ctx = nn.NameCtx(name="<unnamed-param>", parent=None)
name_ctx.require_global_access = True
data = Data("parameter", dim_tags=list(shape), dtype=dtype)
layer_dict = {"class": "variable", "shape": list(shape), "param_name": "param"}
if dtype is not None:
Expand Down
3 changes: 2 additions & 1 deletion nn/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __init__(
self._subnet_main_output = None # type: Optional[nn.Tensor] # when this is via SubnetworkLayer
self.virtual = virtual # does not consume a layer name in RETURNN. see get_name_in_ctx
self.can_access_children = can_access_children # from outside
self.require_global_access = False # from outside
self.new_control_flow_ctx = new_control_flow_ctx
self.children = {} # type: Dict[str, NameCtx]
self.extern_data = {} # type: Dict[str, nn.Data] # only for the root name ctx
Expand Down Expand Up @@ -441,7 +442,7 @@ def _auto_assign_parent(self, *, root: nn.NameCtx):
if parent_module_calls:
parent_name_ctx = parent_module_calls[0]
sub_name = attr
if self.tensor.require_global_access and not parent_name_ctx.can_access_children_from_root:
if self.require_global_access and not parent_name_ctx.can_access_children_from_root:
sub_name = parent_name_ctx.name + "_" + sub_name
while not parent_name_ctx.can_access_children_from_root:
parent_name_ctx = parent_name_ctx.parent
Expand Down

0 comments on commit 2064a32

Please sign in to comment.