from transformers import Trainer, PreTrainedModel
from rlhf.data.data import RewardDataCollatorWithPadding
from torch import nn
from rlhf.optimizer.lion import DecoupledLionW
from transformers.trainer_pt_utils import get_parameter_names
import torch
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS


class RewardTrainer(Trainer):
    """Trainer for pairwise reward modelling on (chosen, rejected) response pairs."""

    def __init__(
        self,
        model: PreTrainedModel,
        args,
        data_collator=None,
        train_dataset=None,
        eval_dataset=None,
        tokenizer=None,
        model_init=None,
        compute_metrics=None,
        callbacks=None,
        optimizers=(None, None),
        preprocess_logits_for_metrics=None,
        max_length: int = 512,
        use_lion: bool = False,
    ):
        # data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.use_lion = use_lion

    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_chosen = model(input_ids=inputs["input_ids_chosen"], attention_mask=inputs["attention_mask_chosen"])
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"], attention_mask=inputs["attention_mask_rejected"]
        )
        # Pairwise ranking loss: push the reward of the chosen response above
        # the reward of the rejected one.
        loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if return_outputs:
            # Keep the signature compatible with Trainer, which may request the outputs.
            return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
        return loss

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model

        # Apply weight decay to every parameter except LayerNorm weights and biases.
        decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                ],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                ],
                "weight_decay": 0.0,
            },
        ]

        # Choose between AdamW (default) and the decoupled Lion optimizer.
        if not self.use_lion:
            optimizer_cls = torch.optim.AdamW
            optimizer_kwargs = {
                "lr": self.args.learning_rate,
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
            }
        else:
            # DecoupledLionW reuses the Adam beta hyper-parameters but takes no epsilon.
            optimizer_cls = DecoupledLionW
            optimizer_kwargs = {
                "lr": self.args.learning_rate,
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
            }

        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        # print(f"Using optimizer {self.optimizer}")

        return self.optimizer
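

# ---------------------------------------------------------------------------
# Minimal usage sketch (commented out; not part of the module). Assumptions,
# all hypothetical: `reward_model` is a PreTrainedModel whose
# forward(input_ids, attention_mask) returns a scalar reward per sequence,
# `tokenizer` is its tokenizer, and `train_pairs` is a dataset whose examples
# carry the fields read by compute_loss: input_ids_chosen,
# attention_mask_chosen, input_ids_rejected, attention_mask_rejected.
# ---------------------------------------------------------------------------
# from transformers import TrainingArguments
#
# args = TrainingArguments(
#     output_dir="reward_model",
#     per_device_train_batch_size=4,
#     learning_rate=1e-5,
#     num_train_epochs=1,
#     remove_unused_columns=False,  # keep the chosen/rejected columns in the batch
# )
# trainer = RewardTrainer(
#     model=reward_model,
#     args=args,
#     data_collator=RewardDataCollatorWithPadding(tokenizer, max_length=512),
#     train_dataset=train_pairs,
#     tokenizer=tokenizer,
#     use_lion=True,  # route create_optimizer() to DecoupledLionW instead of AdamW
# )
# trainer.train()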