
KTOTrainer should work when actual batch size==1 #2554

Open
@starmpcc

Description

if args.per_device_train_batch_size <= 1:
    raise ValueError(
        "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
    )

This check was introduced in #2153.
However, the KL logits are computed after unlinking prompt_input_ids from answer_input_ids (pairing prompts with mismatched answers), so the KL term is not equivalent to the reward term.
KTOTrainer should therefore work when the actual batch size is 1.
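To make the disputed point concrete, here is a minimal sketch. It assumes the mismatch is produced by rotating answers by one position within each batch; `make_kl_pairs` and `has_matched_pair` are hypothetical names for illustration, not TRL's API. Under that assumption every prompt gets a foreign answer once the batch holds more than one example, but a batch of one rotates onto itself, which is the situation the existing check describes. Whether batch size 1 is safe therefore hinges on whether the mismatched answers come from outside the per-device batch.

```python
from typing import List, Tuple


def make_kl_pairs(prompts: List[str], answers: List[str]) -> List[Tuple[str, str]]:
    """Hypothetical helper (assumption, not TRL's API): pair each prompt with
    the answer of the previous example in the same batch."""
    if len(prompts) != len(answers):
        raise ValueError("prompts and answers must have the same length")
    # Rotate answers right by one: [a0, a1, a2] -> [a2, a0, a1]
    shifted = answers[-1:] + answers[:-1]
    return list(zip(prompts, shifted))


def has_matched_pair(prompts: List[str], answers: List[str]) -> bool:
    """True if any KL pair is identical to an original (prompt, answer) pair,
    i.e. the KL term would reuse a pair that also feeds the reward."""
    original = set(zip(prompts, answers))
    return any(pair in original for pair in make_kl_pairs(prompts, answers))


if __name__ == "__main__":
    # Batch of 3: every prompt is paired with a different example's answer.
    print(make_kl_pairs(["p0", "p1", "p2"], ["a0", "a1", "a2"]))
    # [('p0', 'a2'), ('p1', 'a0'), ('p2', 'a1')]
    print(has_matched_pair(["p0", "p1", "p2"], ["a0", "a1", "a2"]))  # False

    # Batch of 1: the rotation is the identity, so the KL pair is the reward pair.
    print(make_kl_pairs(["p0"], ["a0"]))     # [('p0', 'a0')]
    print(has_matched_pair(["p0"], ["a0"]))  # True
```

If the trainer instead drew mismatched answers from a pool larger than the per-device batch (for example, a differently sized preprocessing batch), the batch-of-one case would no longer collapse, which is the scenario this issue argues for.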

Thank you!
