Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 73 additions & 27 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,72 +2515,118 @@ class WindowLayer(_ConcatInputLayer):
layer_class = "window"
recurrent = True # we must not allow any shuffling in the time-dim or so

def __init__(self, window_size=None, window_dim=None, window_left=None, window_right=None,
             axis="T", out_spatial_dim=None, padding="same", stride=1, **kwargs):
    """
    Builds sliding windows over one axis of the input.
    The window becomes a new axis which is inserted right after the windowed axis.

    When ``axis == "T"`` and the (batch-major) input has no time axis,
    the layer assumes that it runs inside a RecLayer.
    It then keeps the last ``window_size`` frames as recurrent state (``rec_vars_outputs["state"]``),
    and only a window into the past is supported.

    :param int|None window_size: size of the window.
      If not given, it is taken from ``window_dim.dimension``.
    :param DimensionTag|None window_dim: dim tag for the window axis.
      Must have a static dimension if ``window_size`` is not given.
    :param int|None window_left: passed on to ``windowed_nd``.
      In the recurrent variant, it must be ``window_size - 1`` if given.
    :param int|None window_right: passed on to ``windowed_nd``.
      In the recurrent variant, it must be 0 if given.
    :param DimensionTag|str axis: see :func:`Data.get_axis_from_description`
    :param DimensionTag|None out_spatial_dim: if given, it must match the dim tag
      which the output template has on the windowed axis.
    :param str padding: "same" or "valid"
    :param int stride: return only each Nth window
    :param kwargs: passed on to the base class
    """
    super(WindowLayer, self).__init__(**kwargs)
    if not window_size:
        # No explicit size given, so the window dim tag must provide a static one.
        assert window_dim and window_dim.dimension
        window_size = window_dim.dimension
    data = self.input_data.copy_as_batch_major()
    if axis == "T" and data.time_dim_axis is None:
        # Assume inside RecLayer.
        assert self._rec_previous_layer, "%s: expected to be used inside a RecLayer" % self
        assert padding == "same"
        assert window_right is not None or window_left is not None, (
            "%s: recurrent variant should explicitly specify window_right=0 or window_left=window_size-1" % self)
        if window_left is not None:
            assert window_size == window_left + 1, "%s: recurrent variant can only have window into the past" % self
        if window_right is not None:
            assert window_right == 0, "%s: recurrent variant can only have window into the past" % self
        prev_state = self._rec_previous_layer.rec_vars_outputs["state"]  # (batch,window,...)
        # Shift the window by one frame: drop the oldest entry, append the current input.
        next_state = tf.concat(
            [prev_state[:, 1:], tf.expand_dims(data.placeholder, axis=1)], axis=1)  # (batch,window,...)
        self.rec_vars_outputs["state"] = next_state
        self.output.placeholder = next_state

    else:
        axis = data.get_axis_from_description(axis)
        new_dim_axis = axis + 1  # new axis will be added right after
        in_spatial_dim = data.dim_tags[axis]
        out_spatial_dim_ = self.output.dim_tags[axis]
        if out_spatial_dim:
            assert out_spatial_dim_ == out_spatial_dim
        if (padding.lower() == "same" or window_size == 1) and stride == 1:  # no change in spatial dim
            # Compare against the resolved output tag, not the optional user argument:
            # out_spatial_dim defaults to None, which would fail here for the default configuration.
            assert in_spatial_dim == out_spatial_dim_
        if in_spatial_dim != out_spatial_dim_ and out_spatial_dim_.dimension is None:
            # New dynamic spatial dim: make sure its sizes are defined.
            if not out_spatial_dim_.dyn_size_ext:
                out_spatial_dim_.dyn_size_ext = in_spatial_dim.dyn_size_ext.copy_template(
                    name="%s:spatial-size" % self.name)
            if out_spatial_dim_.dyn_size_ext.placeholder is None:
                from ..util.basic import same_control_flow_ctx
                assert in_spatial_dim.dyn_size is not None
                size = in_spatial_dim.dyn_size
                # Same size formula as for a convolution with filter size = window size.
                with same_control_flow_ctx(size):
                    size = ConvLayer.calc_out_dim(
                        in_dim=size,
                        filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
                out_spatial_dim_.dyn_size_ext.placeholder = size

        from returnn.tf.util.basic import windowed_nd
        self.output.placeholder = windowed_nd(
            data.placeholder,
            window_size=window_size, window_left=window_left, window_right=window_right,
            padding=padding, time_axis=axis, new_window_axis=new_dim_axis, stride=stride)
    self.output.placeholder.set_shape(tf.TensorShape(self.output.batch_shape))

@classmethod
def get_out_data_from_opts(cls, name, sources, window_size=None, window_dim=None,
                           axis="T", out_spatial_dim=None, padding="same", stride=1,
                           **kwargs):
    """
    Creates the output template: the (batch-major) input with a new window axis
    inserted right after the windowed axis.
    The dim tag on the windowed axis is replaced
    unless the spatial dim stays unchanged ("same" padding or window size 1, with stride 1).

    :param str name:
    :param list[LayerBase] sources:
    :param int|None window_size: if not given, taken from ``window_dim.dimension``
    :param DimensionTag|None window_dim: dim tag for the new window axis.
      If not given, a new spatial dim tag with static dimension ``window_size`` is created.
    :param DimensionTag|str axis:
    :param DimensionTag|None out_spatial_dim: dim tag for the windowed axis in the output.
      Must not be given for the variant inside a RecLayer.
      If not given and the spatial dim changes, a new dim tag is created.
    :param str padding:
    :param int stride:
    :rtype: Data
    """
    if not window_size:
        # No explicit size given, so the window dim tag must provide a static one.
        assert window_dim and window_dim.dimension
        window_size = window_dim.dimension
    data = get_concat_sources_data_template(sources)
    data = data.copy_template(name="%s_output" % name)
    data = data.copy_as_batch_major()
    if axis == "T" and data.time_dim_axis is None:
        # Assume inside RecLayer.
        assert not out_spatial_dim
        new_dim_axis = 1  # after batch
    else:
        axis = data.get_axis_from_description(axis)
        in_spatial_dim = data.dim_tags[axis]
        if (padding.lower() == "same" or window_size == 1) and stride == 1:  # no change in spatial dim
            out_spatial_dim = in_spatial_dim  # error check in __init__
        else:  # new spatial dim
            if not out_spatial_dim:
                # Static size can be computed right away. Otherwise it stays dynamic (None).
                dim = None
                if in_spatial_dim.dimension is not None:
                    dim = ConvLayer.calc_out_dim(
                        in_dim=in_spatial_dim.dimension,
                        filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
                out_spatial_dim = DimensionTag(
                    kind=DimensionTag.Types.Spatial, description="%s:spatial" % name,
                    dimension=dim, batch=data.batch, control_flow_ctx=data.control_flow_ctx)
            data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
        new_dim_axis = axis + 1  # add new axis right after
    if window_dim:
        assert window_dim.dimension == window_size
    else:
        window_dim = DimensionTag(
            kind=DimensionTag.Types.Spatial, description="%s:window" % name, dimension=window_size)
    return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True)

# noinspection PyMethodOverriding
@classmethod
Expand Down