
How to flatten (pack) sequence before softmax + loss or just loss #68

@albertz

See also the generic discussion on losses: #38.
However, this issue here is somewhat orthogonal.

Consider the case where we have wrapped sparse_softmax_cross_entropy_with_logits somehow. This can be applied to inputs [B,T,D] and targets [B,T], or we can flatten (pack) them (removing the padded frames) and then operate on inputs [B',D] and targets [B'].

The packed case should be more efficient because we don't calculate the potentially costly softmax (or log softmax) for the padded frames.
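
As a rough illustration in plain TensorFlow (not the returnn-common API; the function names padded_ce / packed_ce and the seq_lens argument are just for this sketch), the two variants could look like this:

```python
import tensorflow as tf


def padded_ce(logits, targets, seq_lens):
    """logits [B,T,D], targets [B,T], seq_lens [B]. The (log-)softmax is computed
    for every frame, including padding; padded frames are masked out afterwards."""
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)  # [B,T]
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(logits)[1], dtype=ce.dtype)  # [B,T]
    return tf.reduce_sum(ce * mask)


def packed_ce(logits, targets, seq_lens):
    """Same inputs, but flattened (packed) first, so the (log-)softmax is only
    computed for the B' = sum(seq_lens) non-padded frames."""
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(logits)[1])  # [B,T] bool
    logits_packed = tf.boolean_mask(logits, mask)    # [B',D]
    targets_packed = tf.boolean_mask(targets, mask)  # [B']
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets_packed, logits=logits_packed)  # [B']
    return tf.reduce_sum(ce)
```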

How would we do this here? Following from #38, the straightforward thing for the user would be to just call sparse_softmax_cross_entropy_with_logits on the outputs + targets, but that would operate on the padded [B,T] tensors. Some options:

  • The user would explicitly need to think about this and call flatten_batch beforehand in order to make use of this (along the lines of the packed variant sketched above).
    This is maybe unnatural, and more complicated than it needs to be.
    The flatten_batch logic might also need to be extended, because in the end RETURNN needs to know the original sequence lengths to do the frame accumulation correctly. Although maybe it already has that information in the RETURNN BatchInfo?
  • Support flattened (packed, ragged) tensors in a more direct way here in returnn-common (see the ragged-tensor example after this list). Maybe even make this a potential optimization, like the reordering of axes which some layers can do and other layers can undo when needed.
    How? This is not really clear, and it might be a really major undertaking.
    Also, how does this help here? Would the loss function automatically convert its input into a packed tensor? But why only the loss function, and not other modules such as Linear? There are many open questions here.
  • Support flattened (packed, ragged) tensors in a more direct way in RETURNN core.
    This could also be a major undertaking. It is partly supported already, but only to a very limited extent.
    This also leaves open the question of how it would actually be used then.
  • RETURNN could do an optimization: for every loss, go backwards through the dependencies to find the last sequence-level op (layers with recurrent=True), apply flattening to that output, and then repeat all subsequent layers on the packed tensors. That would even include e.g. a preceding LinearLayer.
    It could be a bit tricky in some strange edge cases, but in general this should be doable without too much effort.
  • RETURNN-common could do a similar optimization, going backwards from the losses (see the traversal sketch further below). Maybe it is slightly simpler at this level. We also do not really need to support all cases, but it should be at least as efficient as what we had before.
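
For the "flattened (packed, ragged) tensors" options above, a ragged tensor is one possible concrete representation. A minimal standalone TensorFlow sketch (toy sizes, nothing returnn-common specific):

```python
import tensorflow as tf

B, T, D = 2, 4, 3  # toy sizes, for illustration only
logits = tf.random.normal([B, T, D])                 # padded [B,T,D]
targets = tf.constant([[1, 2, 0, 0], [2, 1, 1, 0]])  # padded [B,T]
seq_lens = tf.constant([2, 3])                       # [B]

# Ragged views of shape [B,None,D] and [B,None]; the padded frames are dropped.
logits_ragged = tf.RaggedTensor.from_tensor(logits, lengths=seq_lens)
targets_ragged = tf.RaggedTensor.from_tensor(targets, lengths=seq_lens)

# .values is the packed (flattened) representation: [B',D] and [B'] with B' = sum(seq_lens).
# The original sequence lengths stay available via .row_lengths(), e.g. for frame accumulation.
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=targets_ragged.values, logits=logits_ragged.values)  # [B']
loss = tf.reduce_sum(ce)
```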

So far, the last option sounds like the most reasonable one to me, but I'm not really sure.
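
To make the last two options a bit more concrete: below is a rough, self-contained sketch of the backward traversal on a toy graph representation (a hypothetical Layer class with deps and a recurrent flag; this is neither RETURNN nor returnn-common code, just an illustration of the idea):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Layer:
    name: str
    recurrent: bool = False  # True for sequence-level ops (e.g. LSTM, self-attention)
    deps: List["Layer"] = field(default_factory=list)


def layers_to_repeat_on_packed(loss_input: Layer) -> List[Layer]:
    """Walk backwards from the loss input and collect all frame-wise layers
    (recurrent=False) up to, but excluding, the last sequence-level op.
    These layers (plus the loss itself) could be repeated on packed [B',...]
    tensors, with flattening applied to the output of the sequence-level op."""
    collected, visited, stack = [], set(), [loss_input]
    while stack:
        layer = stack.pop()
        if id(layer) in visited:
            continue
        visited.add(id(layer))
        if layer.recurrent:
            continue  # stop here; flattening would be applied to this layer's output
        collected.append(layer)
        stack.extend(layer.deps)
    return collected


# Example: loss <- linear <- lstm <- data
data = Layer("data")
lstm = Layer("lstm", recurrent=True, deps=[data])
linear = Layer("linear", deps=[lstm])
print([l.name for l in layers_to_repeat_on_packed(linear)])  # ['linear']
```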
