Expand support for Collator caller functions #88

Closed
@kmehant

Description

At this point, fms-acceleration patches only the torch_call function. However, some standard collators, such as DataCollatorForSeq2Seq, do not implement torch_call and instead implement the standard __call__ directly.

See: https://github.com/huggingface/transformers/blob/4d5b45870411053c9c72d24a8e1052e00fe62ad6/src/transformers/data/data_collator.py#L585

We need to update the torch_call patch function so that it also handles collators that only implement __call__.
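To illustrate the idea, here is a minimal sketch of a patch that falls back to __call__ when a collator class has no torch_call. It is not the fms-acceleration implementation: `patch_collate_entry_point`, `tag_batch`, and the two toy collator classes are hypothetical names used only for this example, and the toy collators stand in for the real transformers classes. The patch is applied on the class rather than the instance, because Python looks up `__call__` on the type when an object is called.

```python
import functools


def patch_collate_entry_point(collator_cls, post_process):
    """Wrap torch_call if the class has it, otherwise wrap __call__.

    Hypothetical helper: `post_process` receives the collated batch and
    returns the batch that callers should see. Returns the patched name.
    """
    name = "torch_call" if hasattr(collator_cls, "torch_call") else "__call__"
    original = getattr(collator_cls, name)

    @functools.wraps(original)
    def patched(self, features, *args, **kwargs):
        batch = original(self, features, *args, **kwargs)
        return post_process(batch)

    # Patch the class: calling an instance resolves __call__ on its type,
    # so assigning __call__ on an instance would have no effect.
    setattr(collator_cls, name, patched)
    return name


class TorchCallCollator:
    """Toy stand-in for a collator that implements torch_call."""

    def torch_call(self, features):
        return {"input_ids": [f["input_ids"] for f in features]}

    def __call__(self, features):
        return self.torch_call(features)


class CallOnlyCollator:
    """Toy stand-in for a collator that only implements __call__."""

    def __call__(self, features, return_tensors=None):
        return {"input_ids": [f["input_ids"] for f in features]}


def tag_batch(batch):
    # Placeholder for whatever the real patch would do to the batch.
    batch["patched"] = True
    return batch


patched_names = [
    patch_collate_entry_point(TorchCallCollator, tag_batch),
    patch_collate_entry_point(CallOnlyCollator, tag_batch),
]
features = [{"input_ids": [1, 2]}]
torch_call_batch = TorchCallCollator()(features)
call_only_batch = CallOnlyCollator()(features)
```

With this approach both kinds of collator go through the same post-processing step, and the caller does not need to know which entry point a given collator class implements.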

I am happy to raise a PR.
