@@ -8531,3 +8531,173 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
      kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
    data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames of a time-axis.
  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.

  This layer can be used as a base for auto-regressive self-attention.

  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input to the previously accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by its own separate dim tag.
  For example, in the first frame the output will be of shape `[1] + input_shape`,
  in the second frame `[2] + input_shape`,
  and so on,
  and in the last frame `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but with the time dim tag replaced by the dim tag `new_dim`
  (converted as outside the loop).

  Normally the optimization should not matter for the user,
  i.e. for the user, the logical behavior is always as if inside the rec loop.
  Outside the loop,
  the output represents a tensor of shape `[T, new_dim] + input_shape`,
  although we actually have another `new_dim` outside the loop,
  and `T` is not actually there;
  we still have all the information, because the last frame contains everything.
  This `new_dim` outside the loop stores all the dynamic seq lengths
  per frame of the loop, i.e. the dyn seq lengths are extended to shape [B,T] or [T]
  (unlike the usual [B]).
  This way, following layers use different seq lengths of `new_dim` for different loop frames,
  just as if the `T` dim actually existed.
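
  Example, as a rough sketch (the layer names "k" and "k_accum" are made up for illustration;
  only `"class": "cum_concat"` and `new_dim` are the actual interface of this layer,
  and `new_dim` is a :class:`DimensionTag` created in the config)::

      "k_accum": {"class": "cum_concat", "from": "k", "new_dim": new_dim}

  Inside a rec loop, this accumulates the per-frame output of `k` over `new_dim`;
  following layers (e.g. a dot product and a softmax over `new_dim`)
  can then implement auto-regressive self-attention on top of it.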
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%r must be used inside a RecLayer" % self
    out_axis = self.output.get_axis_from_description(new_dim)
    new_dim_ = self.output.dim_tags[out_axis]
    assert new_dim_.control_flow_ctx == self.output.control_flow_ctx == self.network.get_control_flow_ctx()

    if not self.input_data.has_axis(rec_layer.time_dim_tag):  # inside loop
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames

      if not new_dim_.dyn_size_ext:
        # Unbroadcasting to [B] is not needed because any layers operating on this
        # should be able to handle extended dyn sizes.
        # Clipping it to the max length for sequences in the loop which are already ended
        # (i.e. considering the end flag)
        # is also not needed because any calculations after the end are irrelevant.
        # Note: In case we have some initial state/output, this can be extended.
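        # E.g. in loop frame i (0-based), dyn_size below is i + 1,
        # i.e. new_dim covers the accumulated frames 0..i.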
        dyn_size = self.network.get_rec_step_index() + 1  # scalar
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-inside" % self.name,
          dim_tags=[],  # scalar
          placeholder=dyn_size, dtype="int32",
          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())

    else:  # outside loop
      # If not inside a rec loop, this layer is a no-op on the tensor.
      self.output.placeholder = self.input_data.placeholder

      # However, we used new dim tags, which were already prepared.
      # We now must fill in the extended dynamic size information.
      if not new_dim_.dyn_size_ext:
        # This must match the logic above for inside the loop.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
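        # E.g. for a max seq length of 4, dyn_size is [1, 2, 3, 4],
        # i.e. loop frame t (0-based) covers t + 1 accumulated frames, matching the inside-loop case.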
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[rec_layer.time_dim_tag],
          placeholder=dyn_size, dtype="int32",
          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
    rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
    assert rec_time_dim
    ctx = network.get_control_flow_ctx()
    assert ctx == input_data.control_flow_ctx
    new_dim_in_ctx = new_dim.get_for_batch_ctx(batch=input_data.batch, ctx=ctx)

    if not input_data.has_axis(rec_time_dim):  # inside loop
      assert ctx and ctx.is_loop() and ctx.loop_spatial_dim == rec_time_dim
      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we here copy the input as batch-major, and then add the time axis at axis 1.
      # In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
      # which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim_in_ctx, unbroadcast=True, axis=1)
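      # E.g. an inside-loop input template [B, F] becomes [B, new_dim, F] here;
      # the actual placeholder (growing along new_dim) is set in __init__.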
      return out

    else:  # outside loop
      # Assume that the input has the time dim from the rec layer.
      axis = input_data.get_axis_from_description(rec_time_dim)
      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_in_ctx)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
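      # The accumulated axis new_dim starts with size 0,
      # i.e. the initial "state" is empty along that axis, e.g. shape [B, 0, ..., D].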
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
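      # output.batch_shape has None for the batch dim and all dynamic dims (including new_dim),
      # so the accumulated "state" is allowed to grow along new_dim between loop iterations.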
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}