Commit

compatible with trl v0.13 (#2992)
hjh0119 authored Jan 29, 2025
1 parent 31133df commit d826840
Showing 9 changed files with 108 additions and 15 deletions.
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -30,6 +30,6 @@ tiktoken
tqdm
transformers>=4.33,<4.50
transformers_stream_generator
-trl>=0.11,<0.12
+trl>=0.13,<=0.14
uvicorn
zstandard
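
trl dropped the PPOv2-prefixed classes in favour of PPOTrainer/PPOConfig after the 0.11 line, which is why the pin jumps straight to 0.13. A minimal runtime check mirroring the new pin (not part of this commit; packaging is assumed importable, as it already is elsewhere in this diff):

# Sketch: fail fast when the installed trl falls outside the newly pinned range.
from packaging import version
import trl

def check_trl_version() -> None:
    v = version.parse(trl.__version__)
    if not (version.parse('0.13') <= v <= version.parse('0.14')):
        raise ImportError(f'swift expects trl>=0.13,<=0.14, but found trl=={trl.__version__}')

check_trl_version()
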
2 changes: 1 addition & 1 deletion swift/llm/argument/rlhf_args.py
@@ -78,7 +78,7 @@ def __post_init__(self):
        super().__post_init__()
        self._init_ppo()

-        if self.rlhf_type in ['dpo', 'kto'] and self.train_type == 'full' or self.rlhf_type == 'ppo':
+        if self.rlhf_type in ['dpo', 'kto', 'ppo'] and self.train_type == 'full':
            self.ref_model = self.ref_model or self.model
            self.ref_model_type = self.ref_model_type or self.model_type
            self.ref_model_revision = self.ref_model_revision or self.model_revision
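
The rewrite is a behaviour change for PPO, not just a style fix: with the old operator precedence, the trailing `or self.rlhf_type == 'ppo'` clause fired regardless of train_type, whereas the new form also requires full training. A standalone illustration with hypothetical values:

# For rlhf_type='ppo' with a non-'full' train_type the two conditions disagree.
rlhf_type, train_type = 'ppo', 'lora'
old = rlhf_type in ['dpo', 'kto'] and train_type == 'full' or rlhf_type == 'ppo'  # evaluates to True
new = rlhf_type in ['dpo', 'kto', 'ppo'] and train_type == 'full'                 # evaluates to False
assert old is True and new is False
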
2 changes: 1 addition & 1 deletion swift/llm/train/rlhf.py
@@ -83,7 +83,7 @@ def _get_trainer_kwargs(self):
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
-            if model:
+            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        return trainer_kwargs
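
With the extra clause, PPO always forwards the ref/reward/value keys to the trainer, even when a model is None. A self-contained sketch of the effect, using a SimpleNamespace as a hypothetical stand-in for the training job:

# Hypothetical stand-in reproducing the new guard for rlhf_type='ppo'.
from types import SimpleNamespace

job = SimpleNamespace(
    ref_model=None, reward_model='my-reward-model', value_model=None,
    args=SimpleNamespace(rlhf_type='ppo'))
trainer_kwargs = {}
for key in ['ref', 'reward', 'value']:
    key = f'{key}_model'
    model = getattr(job, key, None)
    if model or job.args.rlhf_type == 'ppo':
        trainer_kwargs[key] = model
assert trainer_kwargs == {'ref_model': None, 'reward_model': 'my-reward-model', 'value_model': None}
# The old `if model:` guard would have produced only {'reward_model': 'my-reward-model'}.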

4 changes: 2 additions & 2 deletions swift/trainers/rlhf_arguments.py
@@ -4,7 +4,7 @@
from trl import DPOConfig as HfDPOConfig
from trl import KTOConfig as HfKTOConfig
from trl import ORPOConfig as HfORPOConfig
-from trl import PPOv2Config as HfPPOv2Config
+from trl import PPOConfig as HfPPOConfig
from trl import RewardConfig as HfRewardConfig

from .arguments import SwiftArgumentsMixin
@@ -36,5 +36,5 @@ class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):


@dataclass
-class PPOConfig(SwiftArgumentsMixin, HfPPOv2Config):
+class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
    pass
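
The swift-side PPOConfig keeps its name; only the trl base class changes. Supporting both pins at once would need a version gate rather than a plain try/except, because pre-0.12 trl also exported an unrelated PPOConfig for its legacy PPO trainer. A sketch keyed to the two pins used here (assumption, not part of this commit):

import trl
from packaging import version

# Pick the trl config class that matches the installed version.
if version.parse(trl.__version__) >= version.parse('0.13'):   # new pin in this commit
    from trl import PPOConfig as HfPPOConfig
else:                                                          # old pin, trl>=0.11,<0.12
    from trl import PPOv2Config as HfPPOConfig
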
1 change: 0 additions & 1 deletion swift/trainers/rlhf_trainer/cpo_trainer.py
@@ -10,7 +10,6 @@
from .rlhf_mixin import RLHFTrainerMixin

del HFCPOTrainer.__init__
-del HFCPOTrainer.get_batch_samples


class CPOTrainer(RLHFTrainerMixin, SwiftMixin, HFCPOTrainer):
100 changes: 98 additions & 2 deletions swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

+import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
@@ -10,7 +11,6 @@
from .rlhf_mixin import RLHFTrainerMixin

del HFDPOTrainer.__init__
-del HFDPOTrainer.get_batch_samples


class DPOTrainer(RLHFTrainerMixin, SwiftMixin, HFDPOTrainer):
@@ -31,4 +31,100 @@ def __init__(self,

        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free
+        self.use_weighting = False

        super().__init__(model, ref_model, *_args, **kwargs)

+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        batch = batch.copy()
+        num_examples = batch['labels'].shape[0] // 2
+        labels = batch.pop('labels', None)
+        if self.is_encoder_decoder:
+            batch['labels'] = labels
+
+        if self.aux_loss_enabled:
+            batch['output_router_logits'] = True
+        outputs = model(**batch, use_cache=False)
+        batch['labels'] = labels
+        if outputs.logits.shape[1] != labels.shape[1]:
+            # for llava, the model returns logits for the entire sequence, including the image tokens
+            # (placed before the text tokens)
+            outputs.logits = outputs.logits[:, -labels.shape[1]:]
+        for key in ['input_ids', 'attention_mask', 'labels']:
+            batch[f'concatenated_{key}'] = batch.pop(key, None)
+        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
+            batch['concatenated_input_ids'] = batch['concatenated_labels']
+
+        all_logits = outputs.logits
+
+        if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
+            # for llava, the model returns logits for the entire sequence,
+            # including the image tokens (placed before the text tokens)
+            seq_len = batch['concatenated_labels'].shape[1]
+            all_logits = all_logits[:, -seq_len:]
+
+        all_logps, size_completion = self.get_batch_logps(
+            all_logits,
+            batch['concatenated_labels'],
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+        )
+
+        output = {}
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        if self.args.rpo_alpha is not None:
+            labels = batch['concatenated_labels'].clone()
+            output['nll_loss'] = cross_entropy_loss(all_logits[:num_examples], labels[:num_examples])
+
+        if self.loss_type == 'ipo':
+            all_logps = all_logps / size_completion
+
+        output['chosen_logps'] = all_logps[:num_examples]
+        output['rejected_logps'] = all_logps[num_examples:]
+        output['mean_chosen_logits'] = all_logits[:num_examples].mean()
+        output['mean_rejected_logits'] = all_logits[num_examples:].mean()
+
+        if self.aux_loss_enabled:
+            output['aux_loss'] = outputs.aux_loss
+
+        return output

+    @staticmethod
+    def get_batch_logps(
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        label_pad_token_id: int = -100,
+        is_encoder_decoder: bool = False,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
+                             f'and labels must have the same shape {labels.shape}')
+        if not is_encoder_decoder:
+            labels = labels[:, 1:].clone()
+            logits = logits[:, :-1, :]
+        else:
+            labels = labels.clone()
+
+        loss_mask = labels != label_pad_token_id
+
+        labels[labels == label_pad_token_id] = 0
+
+        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
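
The two overrides pin the trl 0.11-style concatenated forward pass and per-sequence log-probability summation, so swift's behaviour stays unchanged under the new trl. A tiny usage sketch for get_batch_logps with made-up tensors (illustration only; assumes the DPOTrainer above is importable):

import torch
from swift.trainers.rlhf_trainer.dpo_trainer import DPOTrainer

pad = -100
logits = torch.randn(2, 5, 8)                     # (batch=2, seq_len=5, vocab=8)
labels = torch.tensor([[pad, 3, 4, 1, pad],
                       [pad, 2, 2, pad, pad]])
logps, n_tokens = DPOTrainer.get_batch_logps(logits, labels, label_pad_token_id=pad)
print(logps.shape)  # torch.Size([2]): summed log-probability of each completion
print(n_tokens)     # tensor([3, 2]): non-padded tokens per sequence after the shift
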
1 change: 0 additions & 1 deletion swift/trainers/rlhf_trainer/kto_trainer.py
@@ -15,7 +15,6 @@
logger = get_logger()

del HFKTOTrainer.__init__
-del HFKTOTrainer.get_batch_samples


class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):
1 change: 0 additions & 1 deletion swift/trainers/rlhf_trainer/orpo_trainer.py
@@ -9,7 +9,6 @@
from .rlhf_mixin import RLHFTrainerMixin

del HFORPOTrainer.__init__
-del HFORPOTrainer.get_batch_samples


class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
10 changes: 5 additions & 5 deletions swift/trainers/rlhf_trainer/ppo_trainer.py
@@ -5,16 +5,16 @@
from packaging import version
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
-from trl import PPOv2Trainer as HFPPOv2Trainer
+from trl import PPOTrainer as HFPPOTrainer

from swift.utils import patch_getattr
from ..mixin import SwiftMixin

-ppo_trainer_init = HFPPOv2Trainer.__init__
-del HFPPOv2Trainer.__init__
+ppo_trainer_init = HFPPOTrainer.__init__
+del HFPPOTrainer.__init__


-class PPOTrainer(SwiftMixin, HFPPOv2Trainer):
+class PPOTrainer(SwiftMixin, HFPPOTrainer):

    @staticmethod
    @contextmanager
@@ -40,7 +40,7 @@ def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, *
            if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
        }
        ppo_trainer_init(
-            self, config=kwargs['args'], tokenizer=self.tokenizer, policy=model, ref_policy=ref_model, **new_kwargs)
+            self, config=kwargs['args'], tokenizer=self.tokenizer, model=model, ref_model=ref_model, **new_kwargs)
        unwrap_model = self.accelerator.unwrap_model(self.model)
        patch_getattr(unwrap_model.__class__, 'policy')
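
Only the keyword names change in this call: the old pin's PPOv2Trainer took policy=/ref_policy=, while the PPOTrainer targeted here takes model=/ref_model=. If both pins ever had to be supported side by side, the keywords could be chosen by version (sketch only; the commented lines show where it would slot into the call above):

import trl
from packaging import version

def ppo_model_kwargs(model, ref_model):
    # Keyword names differ between the old (<0.12) and new (>=0.13) trl pins.
    if version.parse(trl.__version__) >= version.parse('0.13'):
        return {'model': model, 'ref_model': ref_model}
    return {'policy': model, 'ref_policy': ref_model}

# ppo_trainer_init(self, config=kwargs['args'], tokenizer=self.tokenizer,
#                  **ppo_model_kwargs(model, ref_model), **new_kwargs)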

