
Commit: preprocessing caching

jdchang1 committed Nov 11, 2023
1 parent 37fcdda commit af7fe52
Showing 7 changed files with 60 additions and 33 deletions.
cfgs/alg/ppo.yaml (6 additions, 4 deletions)

@@ -142,9 +142,12 @@ tldr:
 seed: 0
 verbose: 0
 n_iters: 1000
-batch_size: 32
-grad_accumulation: 4
-trajectories_per_update: 128
+#batch_size: 32
+batch_size: 36
+#grad_accumulation: 4
+grad_accumulation: 3
+#trajectories_per_update: 128
+trajectories_per_update: 144
 n_epochs: 4
 gamma: 1.0
 gae_lambda: 0.95
@@ -159,7 +162,6 @@ tldr:
 eval_batch_size: 5
 eval_every: 10
 save_every: 100
-#eval_zero_shot: true
 eval_zero_shot: false
 save_checkpoints: false
 eval_splits: ['val']
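Note: both the old and new settings keep trajectories_per_update an exact multiple of batch_size (128 = 4 x 32, 144 = 4 x 36), which PPO-style update loops typically require so each epoch splits into whole minibatches. A quick check of that property with OmegaConf, the config library this repo already uses; the divisibility requirement itself is an assumption, not something this diff states:

# Hypothetical sanity check for the new values; assumes (not confirmed by
# this commit) that TRIL's PPO loop needs whole minibatches per update.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"batch_size": 36, "grad_accumulation": 3, "trajectories_per_update": 144}
)
assert cfg.trajectories_per_update % cfg.batch_size == 0
print(cfg.trajectories_per_update // cfg.batch_size)  # 4 minibatches per epoch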
cfgs/task/tldr.yaml (2 additions, 1 deletion)

@@ -7,7 +7,8 @@ task:
   max_prompt_length: 500

 sampling:
-  batch_size_per_process: 32
+  #batch_size_per_process: 32
+  batch_size_per_process: 36
   max_prompt_len: 500
   max_gen_len: 50
   prompt_padding_side: left
examples/tldr/tldr_ppo.sh (1 addition, 1 deletion)

@@ -1,2 +1,2 @@
 #!/bin/bash
-accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 4 main.py task=tldr alg=ppo
+accelerate launch --config_file accelerate_cfgs/deepspeed_config.yaml --num_processes 1 main.py task=tldr alg=ppo
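Note: the launcher drops from 4 data-parallel processes to 1, and the sampling batch size in cfgs/task/tldr.yaml is explicitly per process, so the process count set here scales the global batch. A generic Accelerate sketch (not code from this repo) of how the launched script sees that value:

# Generic sketch, not taken from this repo: each worker can read the
# process count that `accelerate launch --num_processes N` started.
from accelerate import Accelerator

accelerator = Accelerator()
# With --num_processes 1 and batch_size_per_process: 36, the global
# sampling batch would be 36 * 1 = 36 (assuming per-process semantics).
print(accelerator.num_processes)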
src/tril/agent.py (5 additions, 1 deletion)

@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 from accelerate import Accelerator
+from bitsandbytes.optim import AdamW8bit
 from omegaconf import DictConfig, OmegaConf
 from peft import LoraConfig
 from torch.optim import AdamW, Optimizer
@@ -52,8 +53,11 @@ def __init__(
         self.setup_models()

         # Optimizer
+        # self.optimizer_cls = self.cfg.alg.get(
+        #     "optimizer_cls", AdamW
+        # )  # TODO: make optimizer class
         self.optimizer_cls = self.cfg.alg.get(
-            "optimizer_cls", AdamW
+            "optimizer_cls", AdamW8bit
         )  # TODO: make optimizer class
         self.optimizer_kwargs = self.cfg.alg.optimizer_kwargs
         if self.reward_cfg is not None:
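Note: the default optimizer class changes from torch.optim.AdamW to bitsandbytes' AdamW8bit, which keeps optimizer state (momentum and variance) in 8 bits and so cuts optimizer memory roughly fourfold. A minimal usage sketch with a placeholder model, assuming the standard bitsandbytes API:

# Minimal sketch of the swapped-in 8-bit optimizer; the Linear layer is a
# placeholder, not TRIL's actual policy network.
import torch.nn as nn
from bitsandbytes.optim import AdamW8bit

policy = nn.Linear(768, 768)
optimizer = AdamW8bit(policy.parameters(), lr=1e-5, weight_decay=0.01)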
src/tril/algorithms/__init__.py (1 addition, 1 deletion)

@@ -1,10 +1,10 @@
 from tril.algorithms.aggrevated import AGGREVATED
 from tril.algorithms.bc import BC
-from tril.algorithms.ppo_pp import PPO_PP
 from tril.algorithms.d2lols import D2LOLS
 from tril.algorithms.gail import GAIL
 from tril.algorithms.lols import LOLS
 from tril.algorithms.ppo import PPO
+from tril.algorithms.ppo_pp import PPO_PP


 class AlgorithmRegistry:
src/tril/algorithms/base_online.py (3 additions, 0 deletions)

@@ -69,6 +69,9 @@ def _setup(self):

         self.metrics = build_metrics(self.cfg.get("eval_metrics", []), self.accelerator)
         self.samples_by_split = build_task(self.task_cfg)
+        import pdb
+
+        pdb.set_trace()

         if hasattr(self.agent.reward, "_spice_metric"):
             assert self.agent.reward is not None
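Note: the import pdb / pdb.set_trace() pair added to _setup() halts every run at that point, so it reads as debugging left in while developing the caching path rather than intended behavior. For reference, an equivalent breakpoint that can be switched off without editing code:

# breakpoint() (Python 3.7+) behaves like pdb.set_trace() by default,
# but can be disabled at launch with PYTHONBREAKPOINT=0.
breakpoint()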
src/tril/tasks/tasks.py (42 additions, 25 deletions)

@@ -1,3 +1,5 @@
+from typing import Dict
+
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -69,55 +71,70 @@ def gen_split_name(split: str):

 class TLDR(BaseTask):
     @classmethod
-    def prepare(cls, split: str, tokenizer_id: str, max_prompt_length: int):
+    def prepare(
+        cls,
+        split: str,
+        tokenizer_id: str,
+        max_prompt_length: int,
+        n_samples: Dict[str, int] = {"valid": 100, "test": 500},
+    ):
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_id
         )  # NOTE: truncation side | right, padding side | left
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
         tokenizer.truncation_side = "right"

-        ds = load_dataset("CarperAI/openai_summarize_tldr")
-        split_name = TLDR.gen_split_name(split)
-        samples = []
-        # TODO: Cache this processing if we can
-        for ix, item in enumerate(
-            tqdm(ds[split_name], desc=f"Preprocessing {split} Prompts")
-        ):
-            # TODO: arguments
-            if split == "val" and ix == 500:
-                break
-            if split == "test" and ix == 100:
-                break
-            # Process Prompt
-            prompt = item["prompt"]
-            tmp = tokenizer.decode(
+        def process_prompts(example, idxs):
+            prompt = example["prompt"]
+            processed_prompt = [p.split("TL;DR:")[0] for p in prompt]
+            tmp = tokenizer.batch_decode(
                 tokenizer(
-                    prompt.split("TL;DR:")[0],
+                    processed_prompt,
                     truncation=True,
                     max_length=max_prompt_length
                     - 5,  # to make sure "TL;DR:" doesn't get truncated
                     add_special_tokens=False,
                 )["input_ids"],
                 skip_special_tokens=True,
-            ).strip()
-            tmp = tmp + "\nTL;DR:"
-            tmp = tokenizer.decode(
+            )
+            tmp = [t.strip() + "\nTL;DR:" for t in tmp]
+            tmp = tokenizer.batch_decode(
                 tokenizer(
                     tmp,
                     truncation=True,
                     max_length=max_prompt_length,
                     add_special_tokens=False,
                 )["input_ids"],
                 skip_special_tokens=True,
-            ).strip()
+            )
+            tmp = [t.strip() for t in tmp]
+            return {"id": idxs, "prompt": tmp, "label": example["label"]}
+
+        ds = load_dataset("CarperAI/openai_summarize_tldr")
+        split_name = TLDR.gen_split_name(split)
+        samples = []
+
+        # Map does caching
+        split_ds = ds[split_name].map(
+            process_prompts, with_indices=True, batched=True, batch_size=1000
+        )
+        n_split = n_samples.get(split_name, len(split_ds))
+        for prompt, label, ids in tqdm(
+            zip(
+                split_ds[:n_split]["prompt"],
+                split_ds[:n_split]["label"],
+                split_ds[:n_split]["id"],
+            ),
+            desc=f"Preprocessing {split} Prompts",
+            total=n_split,
+        ):
             # Create Sample
             sample = Sample(
-                id=ix,
-                prompt_or_input_text=tmp,
-                references=[item["label"]],
-                meta_data={"reference": item["label"]},
+                id=ids,
+                prompt_or_input_text=prompt,
+                references=[label],
+                meta_data={"reference": label},
             )
             samples.append(sample)
         task_instance = cls(samples)
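Note: this is the caching the commit title referss to. Dataset.map fingerprints the mapped function and its arguments and stores the transformed table in the local Arrow cache, so the two-pass truncate-and-retokenize preprocessing runs once and later calls reload the cached result. A standalone sketch of the same pattern on toy data (not the TLDR pipeline):

# Standalone illustration of Dataset.map with a batched, indexed function,
# mirroring process_prompts above. For datasets loaded from the Hub, map
# results are cached on disk keyed by a fingerprint of the function and
# its arguments, so reruns skip the preprocessing.
from datasets import Dataset

ds = Dataset.from_dict({"prompt": ["post one TL;DR: a", "post two TL;DR: b"]})

def strip_tldr(batch, idxs):
    return {
        "id": idxs,
        "prompt": [p.split("TL;DR:")[0].strip() for p in batch["prompt"]],
    }

processed = ds.map(strip_tldr, with_indices=True, batched=True, batch_size=1000)
print(processed["prompt"])  # ['post one', 'post two']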
