Commit: optimizer improvements

jdchang1 committed Nov 13, 2023
1 parent 569f7c0 commit bb8adfd
Showing 8 changed files with 37 additions and 24 deletions.
6 changes: 5 additions & 1 deletion accelerate_cfgs/zero_stage_2_config.json
@@ -14,7 +14,11 @@
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
"contiguous_gradients": true,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"round_robin_gradients": true
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
34 changes: 20 additions & 14 deletions cfgs/alg/ppo.yaml
@@ -106,10 +106,12 @@ commongen:
coeff: 0.0
target_kl: 0.1

- optimizer_kwargs:
-   lr: 5e-6
-   weight_decay: 1e-6
-   eps: 1e-5
+ optimizer:
+   id: adamw
+   args:
+     lr: 5e-6
+     weight_decay: 1e-6
+     eps: 1e-5

scheduler:
id: linear
@@ -144,12 +146,13 @@ tldr:
seed: 0
verbose: 0
n_iters: 1000
- #batch_size: 32
- batch_size: 36
- #grad_accumulation: 4
- grad_accumulation: 3
- #trajectories_per_update: 128
- trajectories_per_update: 144
+ batch_size: 32
+ #batch_size: 36
+ grad_accumulation: 4
+ #grad_accumulation: 2
+ #grad_accumulation: 3
+ trajectories_per_update: 128
+ #trajectories_per_update: 144
n_epochs: 4
gamma: 1.0
gae_lambda: 0.95
@@ -176,10 +179,13 @@
coeff: 0.002
target_kl: 0.1

- optimizer_kwargs:
-   lr: 5e-6
-   weight_decay: 1e-6
-   eps: 1e-5
+ optimizer:
+   id: adamw8bit
+   #id: adamw
+   args:
+     lr: 5e-6
+     weight_decay: 1e-6
+     eps: 1e-5

scheduler:
id: constant
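Note on the config change above: the flat optimizer_kwargs block is replaced by an optimizer block with an id and args, which suggests the optimizer class is now resolved by name (adamw for commongen, adamw8bit for tldr). A minimal sketch of how such a config could be resolved; the registry and helper below are illustrative assumptions, not TRIL's actual API, and the 8-bit option requires the bitsandbytes package:

    import torch

    try:
        import bitsandbytes as bnb  # provides the 8-bit AdamW behind an "adamw8bit" id
    except ImportError:
        bnb = None

    # Hypothetical id -> class registry mirroring the ids used in the YAML above.
    OPTIMIZERS = {
        "adamw": torch.optim.AdamW,
        "adamw8bit": bnb.optim.AdamW8bit if bnb is not None else None,
    }

    def build_optimizer(params, optimizer_cfg):
        """Turn a {"id": ..., "args": {...}} config into an optimizer instance."""
        optim_cls = OPTIMIZERS[optimizer_cfg["id"]]
        if optim_cls is None:
            raise RuntimeError("bitsandbytes is required for the adamw8bit optimizer")
        return optim_cls(params, **optimizer_cfg["args"])

    # e.g. build_optimizer(model.parameters(),
    #                      {"id": "adamw8bit",
    #                       "args": {"lr": 5e-6, "weight_decay": 1e-6, "eps": 1e-5}})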
4 changes: 2 additions & 2 deletions cfgs/task/tldr.yaml
@@ -7,8 +7,8 @@ task:
max_prompt_length: 500

sampling:
- #batch_size_per_process: 32
- batch_size_per_process: 36
+ batch_size_per_process: 32
+ #batch_size_per_process: 36
max_prompt_len: 500
max_gen_len: 50
prompt_padding_side: left
3 changes: 2 additions & 1 deletion examples/tldr/tldr_ppo.sh
@@ -1,2 +1,3 @@
#!/bin/bash
- accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 1 main.py task=tldr alg=ppo
+ accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 4 main.py task=tldr alg=ppo
+ #accelerate launch --config_file accelerate_cfgs/deepspeed3_config.yaml --num_processes 4 main.py task=tldr alg=ppo
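One consistent reading of the batch-size changes in this commit (an inference, not stated in the diff): with --num_processes 4, a sampling batch_size_per_process of 32 yields 32 × 4 = 128 trajectories per update, matching trajectories_per_update: 128 in ppo.yaml, whereas the previously tried 36 per process would give 36 × 4 = 144; grad_accumulation moves from 3 to 4 alongside the smaller per-step batch.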
4 changes: 0 additions & 4 deletions src/tril/agent.py
@@ -195,10 +195,6 @@ def create_fn(params, optim_cls, kwargs):
optimizer = optim_cls(grouped_params, **kwargs)
return optimizer

- import pdb
-
- pdb.set_trace()
-
policy_optimizer = create_fn(
self.policy_named_params, self.optimizer_cls, self.cfg.alg.optimizer.args
)
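The create_fn shown above builds grouped_params before instantiating the optimizer. A sketch of the usual grouping idiom for reference; the no-weight-decay split for biases and norm layers is an assumption about the criterion, not something this diff confirms:

    import torch

    def make_optimizer(named_params, optim_cls, kwargs):
        decay, no_decay = [], []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            if name.endswith("bias") or "norm" in name.lower():
                no_decay.append(param)
            else:
                decay.append(param)
        grouped_params = [
            {"params": decay},
            {"params": no_decay, "weight_decay": 0.0},  # per-group override of kwargs
        ]
        return optim_cls(grouped_params, **kwargs)

    # e.g. make_optimizer(model.named_parameters(), torch.optim.AdamW,
    #                     {"lr": 5e-6, "weight_decay": 1e-6, "eps": 1e-5})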
4 changes: 4 additions & 0 deletions src/tril/algorithms/base_online.py
@@ -514,6 +514,10 @@ def learn(self):
print_memory(self.accelerator, tracemalloc, "sampling")

# =========== Train ===========
+ gc.collect()
+ torch.cuda.empty_cache()
+ gc.collect()
+
with TorchTracemalloc() as tracemalloc:
self.train_step()
if self.verbose > 0:
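The three lines added between the sampling and training phases are a common idiom for handing GPU memory back before the next memory-heavy phase. A standalone sketch of the same pattern:

    import gc

    import torch

    def release_cached_gpu_memory():
        # Collect first so Python objects still holding CUDA tensors are freed,
        gc.collect()
        # then return the now-unused cached blocks to the CUDA allocator,
        torch.cuda.empty_cache()
        # and collect again for anything released as a side effect.
        gc.collect()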
3 changes: 2 additions & 1 deletion src/tril/policies/actor.py
@@ -8,7 +8,7 @@
from huggingface_hub import PyTorchModelHubMixin
from numpy.random import Generator
from omegaconf import DictConfig, OmegaConf
- from peft import LoraConfig, get_peft_model
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.distributions import Categorical
from transformers import BitsAndBytesConfig, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList
@@ -67,6 +67,7 @@ def __init__(
)
self.model.__class__ = override_generation_routines(type(self.model))
if self.peft_config is not None:
+ self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False) # TODO: flag for gradient checkpointing
self.model = get_peft_model(
self.model, self.peft_config, self.policy_adapter_name
)
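The added call mirrors PEFT's usual order of operations when fine-tuning a quantized base model: load with a BitsAndBytesConfig, run prepare_model_for_kbit_training, then attach LoRA adapters with get_peft_model. A minimal sketch under assumed settings; the model name and LoRA hyperparameters are placeholders, not TRIL's defaults:

    import torch
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder base model
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
    )

    # Casts norm/lm_head layers for stability and enables input grads; gradient
    # checkpointing is left off, matching the call added in actor.py.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

    # Placeholder LoRA settings; TRIL passes its own LoraConfig here.
    peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
    model = get_peft_model(model, peft_config)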
3 changes: 2 additions & 1 deletion src/tril/rewards/model_rewards.py
@@ -207,6 +207,7 @@ def compute_reward(
ref_ids=None,
ref_mask=None,
retokenize=True,
+ scale_by_ref=False,
):
self.model.set_adapter(self.rm_adapter_name)
self.model.eval()
@@ -261,7 +262,7 @@ def compute_reward(
rewards = rewards.cpu()

# Ref norm
- if ref_ids is not None:
+ if ref_ids is not None and scale_by_ref:
if retokenize:
# Retokenize:
samples = tokenizer.batch_decode(ref_ids, skip_special_tokens=True)
