diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 1ccb14980..28e5edc77 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -387,16 +387,26 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False): dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data if not dyn_size_ext and allow_none and not same_base.derived_from_op: return None + ctx = dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx + if (same_base.batch or batch) == batch and (same_base.control_flow_ctx or ctx) == ctx: + # The same_base instance is either undefined (no batch, no ctx) or it is defined for the same batch and ctx. + # In any case, reuse it then. + same_base.batch = batch + same_base.control_flow_ctx = ctx + if dyn_size_ext: + same_base.dyn_size_ext = dyn_size_ext + same_base.complete_dyn_size(template_only=True) + return same_base dim_tag = Dim( kind=self.kind, description=self.description, dimension=self.dimension, auto_generated=self.auto_generated, - batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, + batch=batch, control_flow_ctx=ctx, dyn_size_ext=dyn_size_ext) dim_tag.same_as = same_base if dyn_size_ext and dyn_size_ext.placeholder is not None: if Dim.get_tag_from_size_tensor(dyn_size_ext.placeholder) is None: dim_tag.set_tag_on_size_tensor(dyn_size_ext.placeholder, batch=batch) - same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag + same_base._same_for_batch_ctx[(batch, ctx)] = dim_tag dim_tag.complete_dyn_size(template_only=True) return dim_tag