Commit e4ff83a

small edits
jdchang1 committed Nov 12, 2023
1 parent af7fe52 commit e4ff83a
Showing 6 changed files with 64 additions and 27 deletions.
10 changes: 6 additions & 4 deletions cfgs/alg/gail.yaml
@@ -18,7 +18,9 @@ imdb:
       label_ix: 1
       include_prompt_for_eval: True
       is_trainable: True
-    optimizer_kwargs:
-      lr: ${alg.optimizer_kwargs.lr}
-      weight_decay: ${alg.optimizer_kwargs.weight_decay}
-      eps: 1e-5
+    optimizer:
+      id: adamw
+      args:
+        lr: ${alg.optimizer_kwargs.lr}
+        weight_decay: ${alg.optimizer_kwargs.weight_decay}
+        eps: 1e-5
10 changes: 6 additions & 4 deletions cfgs/alg/ppo.yaml
@@ -37,10 +37,12 @@ imdb:
     coeff: 0.001
     target_kl: 0.1
 
-  optimizer_kwargs:
-    lr: 1e-5
-    weight_decay: 1e-6
-    eps: 1e-5
+  optimizer:
+    id: adamw
+    args:
+      lr: 1e-5
+      weight_decay: 1e-6
+      eps: 1e-5
 
   scheduler:
     id: linear
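Both algorithm configs replace the flat `optimizer_kwargs` mapping with a nested `optimizer` block: `id` selects the optimizer class (resolved by the new `get_optimizer_cls` helper further down) and `args` holds the constructor keyword arguments. A minimal sketch of the new layout and its access paths, assuming OmegaConf (the `${...}` interpolation in these configs is OmegaConf syntax); the values mirror ppo.yaml above:

from omegaconf import OmegaConf

# Illustrative only: the new nested layout built as a plain dict,
# using the ppo.yaml values shown above.
cfg = OmegaConf.create(
    {
        "alg": {
            "optimizer": {
                "id": "adamw",
                "args": {"lr": 1e-5, "weight_decay": 1e-6, "eps": 1e-5},
            }
        }
    }
)

print(cfg.alg.optimizer.id)                            # adamw
print(OmegaConf.to_container(cfg.alg.optimizer.args))  # {'lr': 1e-05, 'weight_decay': 1e-06, 'eps': 1e-05}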
2 changes: 1 addition & 1 deletion examples/imdb/imdb_gail.sh
@@ -1,2 +1,2 @@
 #!/bin/bash
-accelerate launch --config_file accelerate_cfgs/fsdp_config.yaml --main_process_port 29636 --num_processes 1 main.py task=imdb alg=gail
+accelerate launch --config_file accelerate_cfgs/fsdp_config.yaml --num_processes 1 main.py task=imdb alg=gail
28 changes: 13 additions & 15 deletions src/tril/agent.py
@@ -17,6 +17,7 @@
 
 from tril.policies import PolicyRegistry
 from tril.utils.builders import build_reward_fn
+from tril.utils.helpers import get_optimizer_cls
 
 
 class Agent(nn.Module):
@@ -53,18 +54,11 @@ def __init__(
         self.setup_models()
 
         # Opimizer
-        # self.optimizer_cls = self.cfg.alg.get(
-        #     "optimizer_cls", AdamW
-        # ) # TODO: make optimizer class
-        self.optimizer_cls = self.cfg.alg.get(
-            "optimizer_cls", AdamW8bit
-        ) # TODO: make optimizer class
-        self.optimizer_kwargs = self.cfg.alg.optimizer_kwargs
-        if self.reward_cfg is not None:
-            self.reward_optimizer_cls = self.reward_cfg.get(
-                "optimizer_cls", AdamW
-            ) # TODO: make optimizer class
-            self.reward_optimizer_kwargs = self.reward_cfg.get("optimizer_kwargs", None)
+        self.optimizer_cls = get_optimizer_cls(self.cfg.alg.optimizer.id)
+        if self.reward_cfg is not None and self.reward_cfg.args.get(
+            "is_trainable", False
+        ):
+            self.reward_optimizer_cls = get_optimizer_cls(self.reward_cfg.optimizer.id)
 
     def train(self, mode: bool) -> None:
         """Switches model between train-mode and eval-mode.
@@ -198,20 +192,24 @@ def group_params(params, weight_decay):
             return grouped_parameters
 
         def create_fn(params, optim_cls, kwargs):
-            grouped_params = group_params(params, self.optimizer_kwargs["weight_decay"])
+            grouped_params = group_params(params, kwargs.weight_decay)
             optimizer = optim_cls(grouped_params, **kwargs)
             return optimizer
 
+        import pdb
+
+        pdb.set_trace()
+
         policy_optimizer = create_fn(
-            self.policy_named_params, self.optimizer_cls, self.optimizer_kwargs
+            self.policy_named_params, self.optimizer_cls, self.cfg.alg.optimizer.args
         )
         if not self.cfg.alg.build_reward or not self.reward.is_trainable:
             return policy_optimizer
 
         reward_optimizer = create_fn(
             self.reward_named_params,
             self.reward_optimizer_cls,
-            self.reward_optimizer_kwargs,
+            self.reward_cfg.optimizer.args,
         )
         return policy_optimizer, reward_optimizer
 
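With this change, `Agent` no longer hard-codes `AdamW8bit`/`AdamW` defaults; each component's optimizer class comes from its own `optimizer.id`, and the constructor arguments come from `optimizer.args`. A rough sketch of the resulting construction path, with plain dicts standing in for the OmegaConf nodes, a toy module in place of the policy, and a simplified stand-in for the repo's `group_params` (the real grouping rule may differ):

import torch.nn as nn

from tril.utils.helpers import get_optimizer_cls  # helper added in this commit


def group_params(named_params, weight_decay):
    # Simplified stand-in: apply weight decay to everything except biases.
    decay, no_decay = [], []
    for name, param in named_params:
        (no_decay if name.endswith("bias") else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


policy = nn.Linear(8, 2)  # toy stand-in for the policy network
optimizer_cfg = {"id": "adamw", "args": {"lr": 1e-5, "weight_decay": 1e-6, "eps": 1e-5}}

optim_cls = get_optimizer_cls(optimizer_cfg["id"])  # -> torch.optim.AdamW
grouped = group_params(policy.named_parameters(), optimizer_cfg["args"]["weight_decay"])
optimizer = optim_cls(grouped, **optimizer_cfg["args"])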
24 changes: 21 additions & 3 deletions src/tril/algorithms/base_online.py
@@ -59,6 +59,27 @@ def __init__(
         super().__init__(cfg=cfg, accelerator=accelerator, tracker=tracker)
 
     def _setup(self):
+        # Check config values
+        sampling_check = self.trajectories_per_update % (
+            self.sampling_cfg.batch_size_per_process * self.num_processes
+        )
+        if sampling_check != 0:
+            raise ValueError(
+                "`trajectories_per_update` needs to be divisible by `batch_size_per_process` * `num_processes` for proper distributed gpu training. Please edit these values"
+            )  # noqa
+        batch_check = self.batch_size % (
+            self.grad_accumulation_steps * self.num_processes
+        )
+        if batch_check != 0:
+            raise ValueError(
+                "Set `batch_size` must be achievable with set `grad_accumululation` and `num_processes`. Please edit these values"
+            )  # noqa
+        minibatch_check = self.trajectories_per_update % self.batch_size
+        if minibatch_check != 0:
+            raise ValueError(
+                "`trajectories_per_update` needs to be divisible by `batch_size` for proper training. Please edit these values"
+            )  # noqa
+
         # Build Components
         self.tokenizer = build_tokenizer(self.tokenizer_cfg)
         self.agent = Agent(
@@ -69,9 +90,6 @@ def _setup(self):
 
         self.metrics = build_metrics(self.cfg.get("eval_metrics", []), self.accelerator)
         self.samples_by_split = build_task(self.task_cfg)
-        import pdb
-
-        pdb.set_trace()
 
         if hasattr(self.agent.reward, "_spice_metric"):
             assert self.agent.reward is not None
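The three new guards make sure rollout and optimization sizes line up across processes before training starts: sampled trajectories must split evenly over processes, each optimization batch must split evenly over gradient-accumulation steps and processes, and each update's trajectories must split evenly into batches. A small sanity check with hypothetical values (not the repo's defaults):

# Illustrative values only; they satisfy all three checks above.
num_processes = 2
trajectories_per_update = 128
batch_size_per_process = 16   # sampling batch size per process
batch_size = 32               # optimization batch size
grad_accumulation_steps = 4

assert trajectories_per_update % (batch_size_per_process * num_processes) == 0  # 128 % 32 == 0
assert batch_size % (grad_accumulation_steps * num_processes) == 0              # 32 % 8 == 0
assert trajectories_per_update % batch_size == 0                                # 128 % 32 == 0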
17 changes: 17 additions & 0 deletions src/tril/utils/helpers.py
@@ -6,6 +6,8 @@
 
 import numpy as np
 import torch
+from bitsandbytes.optim import Adam8bit, AdamW8bit
+from torch.optim import Adam, AdamW
 from tqdm import tqdm
 
 from tril.base_task import Sample
@@ -71,6 +73,21 @@ def func(_):
     return value_schedule
 
 
+def get_optimizer_cls(optimizer_id: str):
+    try:
+        optim_cls = {
+            "adam": Adam,
+            "adamw": AdamW,
+            "adam8bit": Adam8bit,
+            "adamw8bit": AdamW8bit,
+        }.get(optimizer_id)
+    except Exception:
+        raise ValueError(
+            f"{optimizer_id} is currently not supported. Please add to tril.utils.helpers."
+        )  # noqa
+    return optim_cls
+
+
 def set_global_logging_level(level=logging.ERROR, prefices=[""]):
     """
     Override logging levels of different modules based on their name as a prefix.
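A quick usage sketch for the new helper, mapping the config `id` strings to the torch and bitsandbytes classes registered above (assuming bitsandbytes is installed, since helpers.py now imports it at module level):

from bitsandbytes.optim import AdamW8bit
from torch.optim import AdamW

from tril.utils.helpers import get_optimizer_cls

assert get_optimizer_cls("adamw") is AdamW
assert get_optimizer_cls("adamw8bit") is AdamW8bit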
