vLLM + LLaMA works well
yangky11 committed Jul 9, 2024
1 parent b9a83f8 commit 3c6dafd
Showing 11 changed files with 79 additions and 83 deletions.
40 changes: 19 additions & 21 deletions common.py
@@ -13,6 +13,7 @@
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict,
)
+from transformers import get_constant_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from typing import Optional, List, Dict, Any, Tuple, Generator
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
@@ -353,18 +354,8 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:
return list(all_pos_premises)


-_SPACES_REGEX = re.compile(r"\s+", re.DOTALL)


-def normalize_spaces(s: str) -> str:
-"""Replace any consecutive block of whitespace characters in ``s`` with a single whitespace."""
-return _SPACES_REGEX.sub(" ", s).strip()


-def format_tactic(annot_tac: str, provenances, normalize: bool) -> str:
+def format_tactic(annot_tac: str, provenances) -> str:
"""Use full names for all the <a>...</a>."""
-if normalize:
-annot_tac = normalize_spaces(annot_tac)
if len(provenances) == 0:
return annot_tac

@@ -412,22 +403,30 @@ def format_augmented_state(


def get_optimizers(
-parameters, trainer: pl.Trainer, lr: float) -> Dict[str, Any]:
+parameters, trainer: pl.Trainer, lr: float, warmup_steps: int
+) -> Dict[str, Any]:
"""Return an AdamW optimizer with cosine warmup learning rate schedule."""
strategy = trainer.strategy

if isinstance(strategy, DeepSpeedStrategy):
if "offload_optimizer" in strategy.config["zero_optimization"]:
logger.info("Optimizing with DeepSpeedCPUAdam")
-return DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
+optimizer = DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
else:
logger.info("Optimizing with FusedAdam")
-return FusedAdam(parameters, lr=lr, adam_w_mode=True)
+optimizer = FusedAdam(parameters, lr=lr, adam_w_mode=True)
else:
logger.info("Optimizing with AdamW")
-return torch.optim.AdamW(parameters, lr=lr)
+optimizer = torch.optim.AdamW(parameters, lr=lr)


+scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps)
+return {
+"optimizer": optimizer,
+"lr_scheduler": {
+"scheduler": scheduler,
+"interval": "step",
+},
+}


def _is_deepspeed_checkpoint(path: str):
@@ -438,14 +437,13 @@ def _is_deepspeed_checkpoint(path: str):

def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool):
"""Handle DeepSpeed checkpoints in model loading."""
-if not _is_deepspeed_checkpoint(ckpt_path):
-model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
-else:
+if _is_deepspeed_checkpoint(ckpt_path):
with tempfile.TemporaryDirectory() as dirname:
path = os.path.join(dirname, "lightning.cpkt")
convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, path)
-model = model_cls.load_from_checkpoint(path, strict=False)
-model = model.to(device)
+model = model_cls.load_from_checkpoint(path, strict=False).to(device)
+else: # PyTorch Lightning checkpoints
+model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
if freeze:
model.freeze()
return model
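For context, a minimal sketch of how the dictionary returned by the updated get_optimizers is consumed: PyTorch Lightning calls configure_optimizers on the module and, because the scheduler entry uses interval: "step", advances the warmup schedule after every optimizer step. The module, learning rate, and warmup value below are illustrative and not part of this commit.

import torch
import pytorch_lightning as pl
from transformers import get_constant_schedule_with_warmup

class TinyModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self, lr: float = 1e-4, warmup_steps: int = 2000):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.lr = lr
        self.warmup_steps = warmup_steps

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        # The LR ramps linearly from 0 to self.lr over warmup_steps, then stays constant.
        scheduler = get_constant_schedule_with_warmup(optimizer, self.warmup_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }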
7 changes: 2 additions & 5 deletions generator/confs/cli_lean4_novel_premises.yaml
@@ -12,9 +12,8 @@ trainer:
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
-project: ReProver
-name: generator_novel_premises
-gradient_clip_val: 1.0
+name: null
+save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -51,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/novel_premises/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
-keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
-normalize_tactics: true
num_workers: 2
8 changes: 2 additions & 6 deletions generator/confs/cli_lean4_random.yaml
@@ -12,10 +12,8 @@ trainer:
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
-project: ReProver
-name: generator_random
-save_dir: logs/generator_random
-gradient_clip_val: 1.0
+name: null
+save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -52,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/random/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
-keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
-normalize_tactics: true
num_workers: 2
7 changes: 4 additions & 3 deletions generator/confs/torchtune-llama3-8B_full.yaml
@@ -49,18 +49,19 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
-output_dir: ./models/Meta-Llama-3-8B-finetuned/
+output_dir: ./models/Meta-Llama-3-8B-finetuned-lr2e-5/
model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
-epochs: 1
+epochs: 5

optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
foreach: False
+warmup_steps: 2000

loss:
_component_: torch.nn.CrossEntropyLoss
@@ -83,6 +84,6 @@ metric_logger:
_component_: torchtune.utils.metric_logging.WandBLogger
project: ReProver
log_dir: ${output_dir}
-output_dir: ./logs/leandojo-llama3-finetune
+output_dir: ./logs/Meta-Llama-3-8B-finetuned-lr2e-5/
log_every_n_steps: 1
log_peak_memory_stats: false
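For intuition, the new warmup_steps: 2000 pairs with lr: 2e-5: the constant-with-warmup schedule wired into the recipe (see full_finetune_distributed.py below) ramps the learning rate linearly and then holds it, i.e. lr(t) = 2e-5 * min(t / 2000, 1), so lr(1000) = 1e-5 and lr(t) stays at 2e-5 for every optimizer step t >= 2000.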
29 changes: 6 additions & 23 deletions generator/datamodule.py
@@ -26,43 +26,36 @@ def __init__(
self,
data_path: str,
corpus: Corpus,
-keep_marks: bool,
preds: List[Dict[str, Any]],
max_inp_seq_len: int,
max_oup_seq_len: int,
p_drop: float,
-normalize_tactics: bool,
tokenizer: ByT5Tokenizer,
is_train: bool,
) -> None:
super().__init__()
self.corpus = corpus
-self.keep_marks = keep_marks
self.preds = preds
self.max_inp_seq_len = max_inp_seq_len
self.max_oup_seq_len = max_oup_seq_len
self.p_drop = p_drop
self.tokenizer = tokenizer
self.is_train = is_train
-self.data = self._load_data(data_path, normalize_tactics)
+self.data = self._load_data(data_path)

-def _load_data(self, data_path: str, normalize_tactics: bool) -> List[Example]:
+def _load_data(self, data_path: str) -> List[Example]:
data = []
for thm in tqdm(json.load(open(data_path))):
for tac in thm["traced_tactics"]:
if "annotated_tactic" in tac:
tactic = format_tactic(*tac["annotated_tactic"], normalize_tactics)
else:
tactic = format_tactic(tac["tactic"], [], normalize_tactics)
if not self.keep_marks:
tactic = remove_marks(tactic)
tactic = remove_marks(tac["tactic"])
data.append(
{
"url": thm["url"],
"commit": thm["commit"],
"file_path": thm["file_path"],
"full_name": thm["full_name"],
"state": format_state(tac["state_before"]),
# "state": format_state(tac["state_before"]),
"state": tac["state_before"],
"tactic": tactic,
}
)
@@ -86,9 +79,7 @@ def __getitem__(self, idx: int) -> Example:
self.p_drop if self.is_train else 0.0,
)

-if not self.keep_marks:
-ex["state"] = remove_marks(ex["state"])

+ex["state"] = remove_marks(ex["state"])
return ex

def collate(self, examples: List[Example]) -> Batch:
@@ -131,14 +122,12 @@ class GeneratorDataModule(pl.LightningDataModule):
def __init__(
self,
data_path: str,
-keep_marks: bool,
model_name: str,
batch_size: int,
eval_batch_size: int,
max_inp_seq_len: int,
max_oup_seq_len: int,
p_drop: float,
-normalize_tactics: bool,
num_workers: int,
corpus_path: Optional[str] = None,
preds_path: Optional[str] = None,
@@ -149,13 +138,11 @@ def __init__(
self.corpus = Corpus(corpus_path)
else:
self.corpus = None
-self.keep_marks = keep_marks
self.batch_size = batch_size
self.eval_batch_size = eval_batch_size
self.max_inp_seq_len = max_inp_seq_len
self.max_oup_seq_len = max_oup_seq_len
self.p_drop = p_drop
-self.normalize_tactics = normalize_tactics
self.num_workers = num_workers
self.tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -177,12 +164,10 @@ def setup(self, stage: Optional[str] = None) -> None:
self.ds_train = GeneratorDataset(
os.path.join(self.data_path, "train.json"),
self.corpus,
-self.keep_marks,
self.preds,
self.max_inp_seq_len,
self.max_oup_seq_len,
self.p_drop,
-self.normalize_tactics,
self.tokenizer,
is_train=True,
)
@@ -191,12 +176,10 @@ def setup(self, stage: Optional[str] = None) -> None:
self.ds_val = GeneratorDataset(
os.path.join(self.data_path, "val.json"),
self.corpus,
-self.keep_marks,
self.preds,
self.max_inp_seq_len,
self.max_oup_seq_len,
self.p_drop,
-self.normalize_tactics,
self.tokenizer,
is_train=False,
)
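For intuition about the simplified data loading above: LeanDojo's annotated tactics wrap premise names in <a>...</a> marks, and after this change the dataset always strips the marks and passes the proof state through unchanged. A rough sketch of the kind of transformation remove_marks performs; the real helper lives in common.py and may differ in detail:

import re

_MARK_REGEX = re.compile(r"</?a>")  # illustrative pattern, not copied from the repo

def remove_marks_sketch(tactic: str) -> str:
    """Strip <a>...</a> premise marks while keeping the premise names."""
    return _MARK_REGEX.sub("", tactic)

print(remove_marks_sketch("rw [<a>Nat.add_comm</a>]"))  # -> rw [Nat.add_comm]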
7 changes: 7 additions & 0 deletions generator/full_finetune_distributed.py
@@ -24,6 +24,8 @@
StateDictType,
)
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from transformers import get_constant_schedule_with_warmup
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils
@@ -226,6 +228,9 @@ def setup(self, cfg: DictConfig) -> None:
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
)
+self._scheduler = get_constant_schedule_with_warmup(
+self._optimizer, cfg.warmup_steps
+)

self._loss_fn = config.instantiate(cfg.loss)

@@ -374,6 +379,7 @@ def _setup_optimizer(

if self._is_rank_zero:
log.info("Optimizer is initialized.")

return optimizer

def _setup_data(
Expand Down Expand Up @@ -534,6 +540,7 @@ def train(self) -> None:
if (idx + 1) % self._gradient_accumulation_steps == 0:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
+self._scheduler.step()

# Update the number of steps when the weights are updated
self.global_step += 1
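To make the scheduler placement above concrete, here is a standalone sketch of the same pattern: the constant-with-warmup scheduler is stepped once per optimizer update, i.e. once every gradient-accumulation window, not once per micro-batch. The model, data, and hyperparameters are illustrative.

import torch
from transformers import get_constant_schedule_with_warmup

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=2000)
grad_accum_steps = 4  # illustrative

for idx in range(100):  # stand-in for iterating over a DataLoader
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum_steps
    loss.backward()
    if (idx + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()  # the learning rate warms up linearly, then stays at 2e-5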
29 changes: 14 additions & 15 deletions generator/model.py
@@ -2,7 +2,6 @@

import os
import ray
-import math
import torch
import shutil
import openai
@@ -12,7 +11,6 @@
import pytorch_lightning as pl
from torchmetrics import Metric
from abc import ABC, abstractmethod
-from vllm import SamplingParams
from typing import List, Dict, Any, Optional, Tuple
from transformers import T5ForConditionalGeneration, AutoTokenizer

@@ -25,6 +23,7 @@
format_augmented_state,
)
from retrieval.model import PremiseRetriever
+from generator.template import StateTacticPairTemplate


torch.set_float32_matmul_precision("medium")
@@ -167,7 +166,9 @@ def training_step(self, batch, batch_idx: int):
return loss

def configure_optimizers(self) -> Dict[str, Any]:
-return get_optimizers(self.parameters(), self.trainer, self.lr)
+return get_optimizers(
+self.parameters(), self.trainer, self.lr, self.warmup_steps
+)

def _log_io_texts(
self,
@@ -542,13 +543,12 @@ def generate(
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
-outputs = ray.get(
-self.vllm_actor.generate.remote(
-f"### State:\n{state}\n\n### Tactic:", num_samples
-)
-)
+# prompt = StateTacticPairTemplate.format({"state": state})
+# prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+prompt = f"[GOAL]\n{state}\n[PROOFSTEP]\n"
+outputs = ray.get(self.vllm_actor.generate.remote(prompt, num_samples))
return [
-(remove_marks(x.text), math.exp(x.cumulative_logprob))
+(remove_marks(x.text).strip(), x.cumulative_logprob)
for x in outputs[0].outputs
]

@@ -560,12 +560,11 @@ def batch_generate(
theorem_pos: List[Pos],
num_samples: int,
) -> List[List[Tuple[str, float]]]:
inputs = [f"### State:\n{s}\n\n### Tactic:" for s in state]
outputs = ray.get(self.vllm_actor.generate.remote(inputs, num_samples))
# prompts = [StateTacticPairTemplate.format({"state": s}) for s in state]
# prompts = [f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{s}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" for s in state]
prompts = [f"[GOAL]\n{s}\n[PROOFSTEP]\n" for s in state]
outputs = ray.get(self.vllm_actor.generate.remote(prompts, num_samples))
return [
[
(remove_marks(x.text), math.exp(x.cumulative_logprob))
for x in oup.outputs
]
[(remove_marks(x.text).strip(), x.cumulative_logprob) for x in oup.outputs]
for oup in outputs
]
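For reference, a minimal sketch of how a vLLM-backed generator could produce the outputs consumed above: the [GOAL]/[PROOFSTEP] prompt goes to vllm.LLM.generate, and each candidate carries a text plus a cumulative_logprob, which is now used directly as the score instead of being exponentiated. The checkpoint path and sampling settings are illustrative, and the Ray actor that wraps vLLM in this repo may differ.

from vllm import LLM, SamplingParams

# Illustrative checkpoint path (the commit fine-tunes Meta-Llama-3-8B with torchtune).
llm = LLM(model="./models/Meta-Llama-3-8B-finetuned-lr2e-5")

state = "n : Nat\n⊢ n + 0 = n"  # example Lean proof state
prompt = f"[GOAL]\n{state}\n[PROOFSTEP]\n"

sampling_params = SamplingParams(n=32, temperature=1.0, max_tokens=256)
outputs = llm.generate([prompt], sampling_params)

suggestions = [
    (out.text.strip(), out.cumulative_logprob)  # score is a log-probability
    for out in outputs[0].outputs
]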
