Skip to content

Commit

Permalink
Dim tags, reuse instances, avoid undefined
Browse files Browse the repository at this point in the history
Related: #975
  • Loading branch information
albertz committed Nov 4, 2022
1 parent 9764482 commit 857ad6d
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,26 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False):
dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
if not dyn_size_ext and allow_none and not same_base.derived_from_op:
return None
ctx = dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx
if (same_base.batch or batch) == batch and (same_base.control_flow_ctx or ctx) == ctx:
# The same_base instance is either undefined (no batch, no ctx) or it is defined for the same batch and ctx.
# In any case, reuse it then.
same_base.batch = batch
same_base.control_flow_ctx = ctx
if dyn_size_ext:
same_base.dyn_size_ext = dyn_size_ext
same_base.complete_dyn_size(template_only=True)
return same_base
dim_tag = Dim(
kind=self.kind, description=self.description, dimension=self.dimension,
auto_generated=self.auto_generated,
batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx,
batch=batch, control_flow_ctx=ctx,
dyn_size_ext=dyn_size_ext)
dim_tag.same_as = same_base
if dyn_size_ext and dyn_size_ext.placeholder is not None:
if Dim.get_tag_from_size_tensor(dyn_size_ext.placeholder) is None:
dim_tag.set_tag_on_size_tensor(dyn_size_ext.placeholder, batch=batch)
same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag
same_base._same_for_batch_ctx[(batch, ctx)] = dim_tag
dim_tag.complete_dyn_size(template_only=True)
return dim_tag

Expand Down

0 comments on commit 857ad6d

Please sign in to comment.