KTOTrainer should work when actual batch size==1 #2554
Open
Description
trl/trl/trainer/kto_trainer.py, lines 662 to 665 in edabe0a
This check was introduced in #2153.
However, the KL logits were calculated by unlinking prompt_input_ids and answer_input_ids, which means the KL term is not equivalent to the reward term. Accordingly, KTOTrainer should work when the actual batch size is 1.
Thank you!