From bb8adfdc5f5a36290185beb061724ec1288c9ce3 Mon Sep 17 00:00:00 2001
From: jdchang1
Date: Mon, 13 Nov 2023 08:57:38 -0500
Subject: [PATCH] optimizer improvements

---
 accelerate_cfgs/zero_stage_2_config.json |  6 ++++-
 cfgs/alg/ppo.yaml                        | 34 ++++++++++++++----------
 cfgs/task/tldr.yaml                      |  4 +--
 examples/tldr/tldr_ppo.sh                |  3 ++-
 src/tril/agent.py                        |  4 ---
 src/tril/algorithms/base_online.py       |  4 +++
 src/tril/policies/actor.py               |  3 ++-
 src/tril/rewards/model_rewards.py        |  3 ++-
 8 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/accelerate_cfgs/zero_stage_2_config.json b/accelerate_cfgs/zero_stage_2_config.json
index 3858a2e..c057fbb 100644
--- a/accelerate_cfgs/zero_stage_2_config.json
+++ b/accelerate_cfgs/zero_stage_2_config.json
@@ -14,7 +14,11 @@
         },
         "allgather_partitions": true,
         "allgather_bucket_size": 5e8,
-        "contiguous_gradients": true
+        "contiguous_gradients": true,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 5e8,
+        "round_robin_gradients": true
     },
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
diff --git a/cfgs/alg/ppo.yaml b/cfgs/alg/ppo.yaml
index e82e8d4..f72f3e0 100644
--- a/cfgs/alg/ppo.yaml
+++ b/cfgs/alg/ppo.yaml
@@ -106,10 +106,12 @@ commongen:
     coeff: 0.0
     target_kl: 0.1
 
-  optimizer_kwargs:
-    lr: 5e-6
-    weight_decay: 1e-6
-    eps: 1e-5
+  optimizer:
+    id: adamw
+    args:
+      lr: 5e-6
+      weight_decay: 1e-6
+      eps: 1e-5
 
   scheduler:
     id: linear
@@ -144,12 +146,13 @@ tldr:
   seed: 0
   verbose: 0
   n_iters: 1000
-  #batch_size: 32
-  batch_size: 36
-  #grad_accumulation: 4
-  grad_accumulation: 3
-  #trajectories_per_update: 128
-  trajectories_per_update: 144
+  batch_size: 32
+  #batch_size: 36
+  grad_accumulation: 4
+  #grad_accumulation: 2
+  #grad_accumulation: 3
+  trajectories_per_update: 128
+  #trajectories_per_update: 144
   n_epochs: 4
   gamma: 1.0
   gae_lambda: 0.95
@@ -176,10 +179,13 @@ tldr:
     coeff: 0.002
     target_kl: 0.1
 
-  optimizer_kwargs:
-    lr: 5e-6
-    weight_decay: 1e-6
-    eps: 1e-5
+  optimizer:
+    id: adamw8bit
+    #id: adamw
+    args:
+      lr: 5e-6
+      weight_decay: 1e-6
+      eps: 1e-5
 
   scheduler:
     id: constant
diff --git a/cfgs/task/tldr.yaml b/cfgs/task/tldr.yaml
index af80e17..74f8b6c 100644
--- a/cfgs/task/tldr.yaml
+++ b/cfgs/task/tldr.yaml
@@ -7,8 +7,8 @@ task:
   max_prompt_length: 500
 
 sampling:
-  #batch_size_per_process: 32
-  batch_size_per_process: 36
+  batch_size_per_process: 32
+  #batch_size_per_process: 36
   max_prompt_len: 500
   max_gen_len: 50
   prompt_padding_side: left
diff --git a/examples/tldr/tldr_ppo.sh b/examples/tldr/tldr_ppo.sh
index f21403b..f5687a2 100755
--- a/examples/tldr/tldr_ppo.sh
+++ b/examples/tldr/tldr_ppo.sh
@@ -1,2 +1,3 @@
 #!/bin/bash
-accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 1 main.py task=tldr alg=ppo
+accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 4 main.py task=tldr alg=ppo
+#accelerate launch --config_file accelerate_cfgs/deepspeed3_config.yaml --num_processes 4 main.py task=tldr alg=ppo
diff --git a/src/tril/agent.py b/src/tril/agent.py
index a3fc882..f5431a2 100644
--- a/src/tril/agent.py
+++ b/src/tril/agent.py
@@ -195,10 +195,6 @@ def create_fn(params, optim_cls, kwargs):
             optimizer = optim_cls(grouped_params, **kwargs)
             return optimizer
 
-        import pdb
-
-        pdb.set_trace()
-
         policy_optimizer = create_fn(
             self.policy_named_params, self.optimizer_cls, self.cfg.alg.optimizer.args
         )
diff --git a/src/tril/algorithms/base_online.py b/src/tril/algorithms/base_online.py
index 956fc59..192a637 100644
--- a/src/tril/algorithms/base_online.py
+++ b/src/tril/algorithms/base_online.py
@@ -514,6 +514,10 @@ def learn(self):
                 print_memory(self.accelerator, tracemalloc, "sampling")
 
             # =========== Train ===========
+            gc.collect()
+            torch.cuda.empty_cache()
+            gc.collect()
+
             with TorchTracemalloc() as tracemalloc:
                 self.train_step()
                 if self.verbose > 0:
diff --git a/src/tril/policies/actor.py b/src/tril/policies/actor.py
index 658bf92..e66282b 100644
--- a/src/tril/policies/actor.py
+++ b/src/tril/policies/actor.py
@@ -8,7 +8,7 @@
 from huggingface_hub import PyTorchModelHubMixin
 from numpy.random import Generator
 from omegaconf import DictConfig, OmegaConf
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from torch.distributions import Categorical
 from transformers import BitsAndBytesConfig, PreTrainedTokenizer
 from transformers.generation.logits_process import LogitsProcessorList
@@ -67,6 +67,7 @@ def __init__(
         )
         self.model.__class__ = override_generation_routines(type(self.model))
         if self.peft_config is not None:
+            self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)  # TODO: flag for gradient checkpointing
             self.model = get_peft_model(
                 self.model, self.peft_config, self.policy_adapter_name
             )
diff --git a/src/tril/rewards/model_rewards.py b/src/tril/rewards/model_rewards.py
index 6ea75a7..1e8b660 100644
--- a/src/tril/rewards/model_rewards.py
+++ b/src/tril/rewards/model_rewards.py
@@ -207,6 +207,7 @@ def compute_reward(
         ref_ids=None,
         ref_mask=None,
         retokenize=True,
+        scale_by_ref=False,
     ):
         self.model.set_adapter(self.rm_adapter_name)
         self.model.eval()
@@ -261,7 +262,7 @@ def compute_reward(
         rewards = rewards.cpu()
 
         # Ref norm
-        if ref_ids is not None:
+        if ref_ids is not None and scale_by_ref:
             if retokenize:
                 # Retokenize:
                 samples = tokenizer.batch_decode(ref_ids, skip_special_tokens=True)
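
Reviewer note on the new `optimizer` config block in cfgs/alg/ppo.yaml: below is a minimal sketch of how an `optimizer.id` / `optimizer.args` pair can be resolved to an optimizer instance. The registry and the `build_optimizer` helper are assumptions for illustration only, not code from this patch (the actual wiring in this repo is `create_fn` in src/tril/agent.py).

    # Illustrative sketch (not part of the patch): map the YAML's
    # `optimizer.id` to a class, then construct it with `optimizer.args`.
    import torch

    try:
        import bitsandbytes as bnb  # needed only for the 8-bit optimizer
        OPTIMIZERS = {"adamw": torch.optim.AdamW, "adamw8bit": bnb.optim.AdamW8bit}
    except ImportError:
        OPTIMIZERS = {"adamw": torch.optim.AdamW}

    def build_optimizer(params, optimizer_cfg):
        # `optimizer_cfg` mirrors the YAML, e.g. {"id": "adamw8bit", "args": {"lr": 5e-6, ...}}
        optim_cls = OPTIMIZERS[optimizer_cfg["id"]]
        return optim_cls(params, **optimizer_cfg["args"])

AdamW8bit keeps optimizer state in 8-bit, which is the main memory saving of switching the tldr config from `adamw` to `adamw8bit`.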
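
Reviewer note on the actor.py hunk: `prepare_model_for_kbit_training` is called before `get_peft_model`, with gradient checkpointing disabled. A standalone sketch of that ordering on a quantized base model follows; the model id, quantization settings, and LoRA hyperparameters are placeholders, not values from this repository.

    # Illustrative sketch (not part of the patch): prepare a k-bit quantized
    # base model for training *before* attaching LoRA adapters.
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder model id
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    # Freezes the base weights and enables input gradients; gradient
    # checkpointing stays off here, matching the TODO in the patch.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

    lora_cfg = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")  # placeholder values
    model = get_peft_model(model, lora_cfg, adapter_name="policy")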