200 changes: 147 additions & 53 deletions returnn/tf/layers/rec.py
@@ -120,7 +120,7 @@ def __init__(self,
:param bool use_global_rec_step_offset:
:param bool include_eos: for search, whether we should include the frame where "end" is True
:param bool|None debug:
:param DimensionTag|None axis:
:param DimensionTag|str|None axis:
"""
super(RecLayer, self).__init__(**kwargs)
import re
@@ -147,6 +147,17 @@ def __init__(self,
self._initial_state_deps = [layer for layer in nest.flatten(initial_state) if isinstance(layer, LayerBase)]
self._input_projection = input_projection
self._max_seq_len = max_seq_len
if isinstance(axis, str):
assert self.input_data
axis_int = self.input_data.get_axis_from_description(axis)
axis = self.input_data.dim_tags[axis_int]
if axis:
assert isinstance(axis, DimensionTag)
if axis and self.input_data and axis in self.input_data.dim_tags:
axis_int = self.input_data.get_axis_from_description(axis)
self.input_data.time_dim_axis = axis_int # makes some of the following code easier
if self.input_data.time_dim_axis == self.input_data.feature_dim_axis:
self.input_data.feature_dim_axis = NotSpecified
self.time_dim_tag = axis
self.include_eos = include_eos
if optimize_move_layers_out is None:
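
The new string handling for axis above resolves a user-facing axis description to an integer position and then to the dim tag stored there. A toy sketch of that resolution pattern, using a simplified stand-in for Data (all names here are illustrative, not the RETURNN API):

# Toy stand-in for the axis-resolution pattern above; illustrative only.
class _DataLike:
    def __init__(self, dim_tags):
        self.dim_tags = dim_tags  # one tag per axis (plain strings here)

    def get_axis_from_description(self, description):
        # map a user-facing axis description to its integer position
        for i, tag in enumerate(self.dim_tags):
            if tag == description:
                return i
        raise ValueError("axis %r not found in %r" % (description, self.dim_tags))

data = _DataLike(dim_tags=("B", "T", "F"))
axis_int = data.get_axis_from_description("T")  # -> 1
axis = data.dim_tags[axis_int]                  # -> the tag object at that axis
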
@@ -311,6 +322,16 @@ def transform_config_dict(cls, d, network, get_layer):
# We need to figure out the output time dim tag at this early point,
# because _SubnetworkRecCell might need it during template construction.
source_data = get_concat_sources_data_template(d["sources"]) if d["sources"] else None
time_dim_tag_explicit = d.get("axis")
if source_data and isinstance(time_dim_tag_explicit, str):
time_dim_tag_explicit = source_data.get_dim_tag_from_description(time_dim_tag_explicit)
if time_dim_tag_explicit:
assert isinstance(time_dim_tag_explicit, DimensionTag)
if source_data and time_dim_tag_explicit and time_dim_tag_explicit in source_data.dim_tags:
# Make sure it is marked as time dim. This will make it easier in the following.
source_data.time_dim_axis = source_data.get_axis_from_description(time_dim_tag_explicit)
if source_data.time_dim_axis == source_data.feature_dim_axis:
source_data.feature_dim_axis = NotSpecified
have_dyn_seq_len_end = False
if isinstance(d.get("unit"), dict):
have_dyn_seq_len_end = "end" in d["unit"]
@@ -341,11 +362,10 @@ def transform_config_dict(cls, d, network, get_layer):
time_dim_tag = DimensionTag(
description="dyn-time:%s%s" % (network.get_absolute_name_prefix(), d["_name"]),
kind=DimensionTag.Types.Time)
if d.get("axis") and time_dim_tag:
time_dim_tag_explicit = d["axis"]
assert isinstance(time_dim_tag_explicit, DimensionTag)
if time_dim_tag_explicit and time_dim_tag:
time_dim_tag.declare_same_as(time_dim_tag_explicit)
d["axis"] = time_dim_tag
if not time_dim_tag_explicit:
d["axis"] = time_dim_tag

if isinstance(d.get("unit"), dict):
sub_net_dict = d.pop("unit")
@@ -386,51 +406,80 @@ def max_len_from(src):
d["max_seq_len"] = eval(d["max_seq_len"], {"max_len_from": max_len_from, "tf": tf})

@classmethod
def get_out_data_from_opts(cls, network, unit, axis=None, sources=(), initial_state=None, **kwargs):
def get_out_data_from_opts(cls, name, network, sources, unit, axis=None, out_dim=None, initial_state=None,
**kwargs):
"""
:param str name:
:param returnn.tf.network.TFNetwork network:
:param list[LayerBase] sources:
:param str|dict[str] unit:
:param DimensionTag|None axis:
:param list[LayerBase] sources:
:param DimensionTag|None out_dim:
:param str|LayerBase|list[str|LayerBase] initial_state:
:rtype: Data
"""
from tensorflow.python.util import nest
source_data = get_concat_sources_data_template(sources) if sources else None
if isinstance(axis, str):
assert source_data
axis_int = source_data.get_axis_from_description(axis)
axis = source_data.dim_tags[axis_int]
if axis:
assert isinstance(axis, DimensionTag)
if source_data and axis in source_data.dim_tags:
# This will make it easier in the following.
source_data.time_dim_axis = source_data.get_axis_from_description(axis)
if source_data.time_dim_axis == source_data.feature_dim_axis:
source_data.feature_dim_axis = NotSpecified
if source_data and source_data.have_time_axis() and not axis:
axis = source_data.get_time_dim_tag()
n_out = kwargs.get("n_out", NotSpecified)
out_type = kwargs.get("out_type", None)
loss = kwargs.get("loss", None)
deps = list(sources) # type: typing.List[LayerBase]
deps += [layer for layer in nest.flatten(initial_state) if isinstance(layer, LayerBase)]
if isinstance(unit, _SubnetworkRecCell): # subnetwork
subnet = unit
out = (
subnet.layer_data_templates["output"].output
.copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
.copy_template_set_ctx(network.get_control_flow_ctx()))
if n_out is not NotSpecified:
assert n_out == out.dim
if out_type:
for k, v in out_type.items():
assert getattr(out, k) == v
out = subnet.layer_data_templates["output"].output.copy_template(name="%s_output" % name)
if axis:
out = out.copy_add_dim_by_tag(dim_tag=axis, axis=0, unbroadcast=True)
out.time_dim_axis = 0
out = out.copy_template_set_ctx(network.get_control_flow_ctx())
deps += subnet.get_parent_deps()
elif out_type or n_out is not NotSpecified or loss:
out = super(RecLayer, cls).get_out_data_from_opts(network=network, sources=sources, **kwargs)
if source_data and not source_data.have_time_axis():
elif out_type or n_out is not NotSpecified or out_dim or loss:
assert source_data
out = source_data.copy_template(name="%s_output" % name)
if out.sparse:
out.dtype = "float32"
out.sparse = False
out = out.copy_add_feature_dim() # dummy
if not out_dim:
if n_out is NotSpecified or not n_out:
assert out_type and "dim" in out_type
n_out = out_type["dim"]
out_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="%s:feature" % name, dimension=n_out)
if out.have_feature_axis():
out = out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim)
else:
out = out.copy_add_dim_by_tag(dim_tag=out_dim, unbroadcast=True, axis=-1)
out.feature_dim_axis = NotSpecified
if out.have_time_axis() and axis == out.get_time_dim_tag():
out = out.copy_as_time_batch_major()
else:
# We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
out = out.copy_as_batch_major() # The output is then [B,F]
else:
out = out.copy_as_time_batch_major() # Otherwise the output is always [T,B,F]
else:
raise Exception("n_out or out_type must be specified")
if out.have_time_axis() and axis:
out = out.copy_template_replace_dim_tag(axis=out.time_dim_axis, new_dim_tag=axis)
raise Exception("n_out or out_type or out_dim must be specified")
if n_out is not NotSpecified:
assert n_out == out.dim
if out_dim:
assert out_dim == out.feature_dim_or_sparse_dim
if out_type:
for k, v in out_type.items():
assert getattr(out, k) == v
for dep in deps:
if dep:
out.beam = SearchBeam.get_combined_beam(out.beam, dep.output.beam)
if out_type:
assert out_type.get("time_dim_axis", out.time_dim_axis) == out.time_dim_axis
assert out_type.get("batch_dim_axis", out.batch_dim_axis) == out.batch_dim_axis
return out
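
The branch above accepts any of out_dim, n_out, or out_type["dim"] as the source of the output feature dim. A condensed sketch of that precedence, with a simplified signature and a hypothetical helper name:

# Condensed sketch of the precedence in the branch above; hypothetical helper.
def resolve_out_dim(out_dim=None, n_out=None, out_type=None):
    if out_dim is not None:
        return out_dim                 # explicit DimensionTag wins
    if n_out:
        return n_out                   # plain int dimension next
    if out_type and "dim" in out_type:
        return out_type["dim"]         # last resort: out_type
    raise Exception("n_out or out_type or out_dim must be specified")

assert resolve_out_dim(n_out=32) == 32
assert resolve_out_dim(out_type={"dim": 32}) == 32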

def get_absolute_name_scope_prefix(self):
@@ -448,10 +497,17 @@ def get_rec_initial_extra_outputs(cls, **kwargs):
"""
sources = kwargs.get("sources")
source_data = get_concat_sources_data_template(sources) if sources else None
if source_data and not source_data.have_time_axis():
# We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
return {"state": RnnCellLayer.get_rec_initial_state(**kwargs)}
return {}
axis = kwargs["axis"]
if isinstance(axis, str):
assert source_data
axis_int = source_data.get_axis_from_description(axis)
axis = source_data.dim_tags[axis_int]
if axis:
assert isinstance(axis, DimensionTag)
if source_data and axis in source_data.dim_tags:
return {}
# We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
return {"state": RnnCellLayer.get_rec_initial_state(**kwargs)}

@classmethod
def get_rec_initial_output(cls, **kwargs):
@@ -558,22 +614,56 @@ def get_rnn_cell_class(cls, name, cell_only=False):
raise Exception("unknown cell %r. known cells: %r" % (name, sorted(cls._rnn_cells_dict.keys())))
return cls._rnn_cells_dict[name.lower()]

def _get_input(self):
def _get_input(self, strict=True):
"""
:return: (x, seq_len), where x is (time,batch,...,dim) and seq_len is (batch,)
:rtype: (tf.Tensor, tf.Tensor)
:param bool strict:
:return: (in_data, x, seq_len), where x is (time,batch,...,dim) and seq_len is (batch,)
:rtype: (Data, tf.Tensor, tf.Tensor)
"""
assert self.input_data
if self.input_data.have_time_axis():
x = self.input_data.copy_as_time_batch_major().placeholder
seq_len = self.input_data.get_sequence_lengths()
return x, seq_len
else: # no time-dim-axis, expect to be inside another RecLayer
# Just add a dummy time dim, and seq_len == 1 everywhere.
x = self.input_data.placeholder
x = tf.expand_dims(x, 0)
seq_len = tf.ones([self.input_data.get_batch_dim()], dtype=self.input_data.size_dtype)
return x, seq_len
in_data = self.input_data
if not in_data.have_time_axis() or not self.time_dim_tag:
in_data = in_data.copy_add_spatial_dim(spatial_dim_axis=0)
in_data.time_dim_axis = 0
if strict:
# Merge all other axes except time and feature such that we get (time,batch',dim).
in_data = in_data.copy_as_time_major()
if not in_data.sparse:
in_data = in_data.copy_with_feature_last()
else:
in_data = in_data.copy_as_time_batch_major()
in_data_ = in_data
if strict and in_data_.batch_ndim_dense != 3:
assert in_data_.batch_ndim_dense > 3
batch_axes = list(range(1, in_data_.batch_ndim_dense - 1))
assert len(batch_axes) >= 2
in_data_ = in_data_.copy_merge_into_batch(batch_axes)
return in_data, in_data_.placeholder, in_data_.get_sequence_lengths()
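
In strict mode, _get_input above flattens every axis besides time and feature into the batch axis, so the cell always sees (time, batch', dim). A plain-TF sketch of that merge (shapes are illustrative):

# Merge all non-time, non-feature axes into batch: (T,B,A,D) -> (T,B*A,D).
import tensorflow as tf

x = tf.zeros([7, 3, 5, 16])                # (time, batch, extra, dim)
t, b, a, d = tf.unstack(tf.shape(x))
x_merged = tf.reshape(x, [t, b * a, d])    # (time, batch*extra, dim)
assert x_merged.shape.as_list() == [7, 15, 16]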

def _post_proc_output_cell_strict(self, y, in_data):
"""
:param tf.Tensor y: (time,batch,dim)
:param Data in_data:
:rtype: tf.Tensor
:return: (time,batch,dim) or (time,batch,...,dim)
"""
if in_data.batch_ndim_dense > 3:
y_shape = tf_util.get_shape(in_data.placeholder)
if not in_data.sparse:
y_shape = y_shape[:-1]
y_shape += [tf_util.get_shape_dim(y, -1)]
y = tf.reshape(y, y_shape)
out_data = in_data.copy_template_dense()
out_data = out_data.copy_template_replace_dim_tag(
axis=out_data.feature_dim_axis, new_dim_tag=self.output.feature_dim_or_sparse_dim)
out_data.placeholder = y
if not self.input_data.have_time_axis() or not self.time_dim_tag:
out_data = out_data.copy_squeeze_axes(axes=[0])
# The output format should match now.
# If this is not the case, we should fix get_out_data_from_opts accordingly
# and avoid unnecessary further transformations here, esp. any transposes.
assert out_data.dim_tags == self.output.dim_tags
return out_data.placeholder
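
_post_proc_output_cell_strict inverts that merge: the leading dims are recovered from the input's shape, and only the last dim is taken from the cell output. A matching plain-TF sketch under the same illustrative shapes:

# Undo the batch merge: (T,B*A,D_out) -> (T,B,A,D_out), leading dims from input.
import tensorflow as tf

x = tf.zeros([7, 3, 5, 16])                          # input (T, B, A, D_in)
y = tf.zeros([7, 3 * 5, 32])                         # cell output (T, B*A, D_out)
y_shape = tf.concat([tf.shape(x)[:-1], tf.shape(y)[-1:]], axis=0)
y = tf.reshape(y, y_shape)                           # (T, B, A, D_out)
assert y.shape.as_list() == [7, 3, 5, 32]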

@classmethod
def get_losses(cls, name, network, output, loss=None, reduce_func=None, layer=None, **kwargs):
@@ -690,16 +780,18 @@ def _get_output_cell(self, cell):
pass
assert self.input_data
assert not self.input_data.sparse
x, seq_len = self._get_input()
in_data, x, seq_len = self._get_input()
if self._direction == -1:
x = tf_compat.v1.reverse_sequence(x, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
if isinstance(cell, BaseRNNCell):
with tf_compat.v1.variable_scope(tf_compat.v1.get_variable_scope(), initializer=self._fwd_weights_initializer):
x = cell.get_input_transformed(x)
if isinstance(cell, rnn_cell.RNNCell): # e.g. BasicLSTMCell
if not self.input_data.have_time_axis():
if not self.input_data.have_time_axis() or not self.time_dim_tag:
assert self._direction in [1, None]
assert self._rec_previous_layer, "%s: assume in loop with input %s, but no rec info" % (self, self.input_data)
y, final_state = cell(self.input_data.placeholder, self._initial_state)
in_data = None # mark that we have not used this
elif self._unroll:
assert self._max_seq_len is not None, "specify max_seq_len for unroll"
# We must get x.shape[0] == self._max_seq_len, so pad it.
@@ -741,6 +833,8 @@ def _get_output_cell(self, cell):
raise Exception("invalid type: %s" % type(cell))
if self._direction == -1:
y = tf_compat.v1.reverse_sequence(y, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
if in_data:
y = self._post_proc_output_cell_strict(y, in_data=in_data)
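
For direction == -1, each sequence is reversed up to its own length before the cell runs, and the output is reversed back afterwards, so padding positions stay in place. A minimal sketch with the current TF API (the code above uses the tf.compat.v1 spelling with batch_dim/seq_dim):

# Per-sequence reversal around the cell; padding positions are untouched.
import tensorflow as tf

x = tf.random.normal([7, 3, 16])                     # (time, batch, dim)
seq_len = tf.constant([7, 5, 2])                     # true length per sequence
x_rev = tf.reverse_sequence(x, seq_lengths=seq_len, seq_axis=0, batch_axis=1)
# ... run the cell on x_rev ...
y = tf.reverse_sequence(x_rev, seq_lengths=seq_len, seq_axis=0, batch_axis=1)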
return y

@staticmethod
@@ -826,7 +920,7 @@ def _get_output_cudnn(self, cell):
assert self._max_seq_len is None
assert self.input_data
assert not self.input_data.sparse
x, seq_len = self._get_input()
in_data, x, seq_len = self._get_input()
n_batch = tf.shape(seq_len)[0]
if self._direction == -1:
x = tf_compat.v1.reverse_sequence(x, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
@@ -872,6 +966,7 @@ def _get_output_cudnn(self, cell):
y, _ = cell(x, initial_state=(input_h, input_c))
if self._direction == -1:
y = tf_compat.v1.reverse_sequence(y, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
y = self._post_proc_output_cell_strict(y, in_data=in_data) # noqa
return y # noqa

def _get_output_native_rec_op(self, cell):
Expand All @@ -884,7 +979,7 @@ def _get_output_native_rec_op(self, cell):

assert self._max_seq_len is None
assert self.input_data
x, seq_len = self._get_input()
in_data, x, seq_len = self._get_input()
if self._input_projection:
if cell.does_input_projection:
# The cell gets x as-is. It will internally do the matrix mult and add the bias.
@@ -907,7 +1002,7 @@ def _get_output_native_rec_op(self, cell):
assert not cell.does_input_projection
assert not self.input_data.sparse
assert self.input_data.dim == cell.n_input_dim
if self.input_data.have_time_axis():
if self.input_data.have_time_axis() and self.time_dim_tag:
index = sequence_mask_time_major(seq_len, maxlen=self.input_data.time_dimension())
else:
index = tf.ones([1, self.input_data.get_batch_dim()], dtype=tf.bool) # see _get_input
Expand All @@ -921,8 +1016,7 @@ def _get_output_native_rec_op(self, cell):
self._last_hidden_state = final_state
if not cell.does_direction_handling:
y = directed(y, self._direction)
if not self.input_data.have_time_axis(): # see _get_input
y = y[0]
y = self._post_proc_output_cell_strict(y, in_data=in_data)
return y

def _get_output_subnet_unit(self, cell):
Expand Down Expand Up @@ -2162,12 +2256,12 @@ def get_output(self):
if rec_layer.input_data:
with tf.name_scope("source_tensor_array"):
# noinspection PyProtectedMember
source, input_seq_len = rec_layer._get_input() # source will be (time,batch,..,dim)
source_data, source, input_seq_len = rec_layer._get_input(strict=False) # (time,batch,..,dim)
source_shape = tf.shape(source, name="source_shape")
source_ta = tf.TensorArray(
name="source_ta",
dtype=rec_layer.input_data.dtype,
element_shape=tf.TensorShape(rec_layer.input_data.copy_template_excluding_time_dim().batch_shape),
element_shape=tf.TensorShape(source_data.copy_template_excluding_time_dim().batch_shape),
size=source_shape[0],
infer_shape=True)
source_ta = source_ta.unstack(source, name="source_ta_unstack")
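
The TensorArray unstacking above lets the loop body read one source frame per step. A small standalone sketch of the same pattern, with illustrative shapes:

# Unstack the source along time; the loop body then reads one frame per step.
import tensorflow as tf

source = tf.zeros([7, 3, 16])                        # (time, batch, dim)
ta = tf.TensorArray(dtype=source.dtype, size=tf.shape(source)[0],
                    element_shape=tf.TensorShape([None, 16]))
ta = ta.unstack(source)
frame0 = ta.read(0)                                  # (batch, dim) at t == 0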