Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3715,14 +3715,15 @@ class RepeatLayer(_ConcatInputLayer):
"""
layer_class = "repeat"

def __init__(self, repetitions, axis="T", **kwargs):
def __init__(self, repetitions, axis="T", out_dim=None, **kwargs):
"""
:param LayerBase|int repetitions:
number of repetitions for each sequence and position in target axis.
Can be [B,T] or [T,B] or some subset of that shape
:param str axis: (dynamic) axis for repetition (currently only time axis is supported)
:param DimensionTag|str axis: (dynamic) axis for repetition (currently only time axis is supported)
:param DimensionTag|None out_dim:
"""
super(RepeatLayer, self).__init__(**kwargs)
super(RepeatLayer, self).__init__(out_dim=out_dim, **kwargs)
self.repetitions = repetitions
if isinstance(self.repetitions, int):
repetitions_data = Data.from_tensor(tf.constant(self.repetitions))
Expand Down Expand Up @@ -3805,7 +3806,7 @@ def copy_placeholder_with_batch_axis(data, other_batch):
# set size placeholders
output_axis = self.output.get_axis_from_description(axis)
tag = self.output.dim_tags[output_axis]
if tag.dimension is None: # dynamic? dyn sizes needed?
if tag.dimension is None and tag.dyn_size is None: # dynamic? dyn sizes needed?
tag.set_tag_on_size_tensor(target_seq_len, batch=self.output.batch)

def get_dep_layers(self):
Expand All @@ -3829,12 +3830,13 @@ def transform_config_dict(cls, d, network, get_layer):
d["repetitions"] = get_layer(d["repetitions"])

@classmethod
def get_out_data_from_opts(cls, name, axis, repetitions, sources=(), **kwargs):
def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None, **kwargs):
"""
:param str name:
:param str axis:
:param LayerBase|int repetitions:
:param list[LayerBase] sources:
:param DimensionTag|str axis:
:param LayerBase|int repetitions:
:param DimensionTag|None out_dim:
:rtype: Data
"""
from ..util.data import DimensionTag
Expand All @@ -3850,8 +3852,11 @@ def get_out_data_from_opts(cls, name, axis, repetitions, sources=(), **kwargs):
else:
new_dim = None
data = data.copy_move_axis(original_axis, data.get_batch_axis(0))
tag = DimensionTag(description="repeated:%s" % name, kind=tag.kind, dimension=new_dim)
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=tag)
if not out_dim:
out_dim = DimensionTag(description="repeated:%s" % name, kind=tag.kind, dimension=new_dim, derived_from_tag=tag)
else:
assert out_dim.dimension == new_dim
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)


class TileLayer(_ConcatInputLayer):
Expand Down