Support batch-level transformations in Encodings #251

Open
@lorenzoh


Sometimes encodings need to be able to take into account batch information, as in a sequence learning task where samples in a batch should be padded to the length of the longest sequence.

Currently, all Encodings transform individual samples, which is great for simplicity and composability, but doesn't allow implementing these batch-level transformations.

Encodings are used in basically every training loop through taskdataloaders, which always yields batches of encoded data. We could have it use a new function encodebatch(encoding, context, block, samples) that transforms multiple samples at a time. This would operate on vectors of samples, not on a collated batch, since not all kinds of data can be collated (e.g. images of different sizes).

By default, it would simply delegate to the single-sample encode function:

function encodebatch(encoding, context, block, observations::AbstractVector)
    map(obs -> encode(encoding, context, block, obs), observations)
end

But individual encodings could override it:

function encodebatch(encoding::PadSequences, context, block, observations::AbstractVector)
    # dummy padding code: pad every sequence to the length of the longest one
    n = maximum(length, observations)
    return map(obs -> pad(obs, n), observations)
end
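
To make the idea concrete, here is a minimal self-contained sketch of the proposed API. The names Encoding, PadSequences, pad, and encodebatch are assumptions taken from this proposal, not existing library API, and the padding scheme (right-padding with a fill value) is just one possible choice:

```julia
# Hypothetical sketch of the proposed batch-level encoding API.
abstract type Encoding end

# Placeholder single-sample encode; real encodings would transform `obs`.
encode(encoding::Encoding, context, block, obs) = obs

# Default: delegate to the single-sample `encode` for each observation.
function encodebatch(encoding::Encoding, context, block, observations::AbstractVector)
    map(obs -> encode(encoding, context, block, obs), observations)
end

# Example batch-level encoding: pad sequences to the longest in the batch.
struct PadSequences <: Encoding
    padvalue::Int
end

# Right-pad `obs` with `padvalue` up to length `n` (assumed helper).
pad(enc::PadSequences, obs::AbstractVector, n) =
    vcat(obs, fill(enc.padvalue, n - length(obs)))

function encodebatch(enc::PadSequences, context, block, observations::AbstractVector)
    n = maximum(length, observations)
    map(obs -> pad(enc, obs, n), observations)
end
```

With this sketch, encodebatch(PadSequences(0), nothing, nothing, [[1, 2], [3, 4, 5]]) pads the first sequence to [1, 2, 0], while the fallback method keeps single-sample encodings working unchanged on batches.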

Tagging relevant parties @Chandu-4444 @darsnack @ToucheSir for discussion.

Labels: api-proposal (Implementation or suggestion for new APIs and improvements to existing APIs)