forked from OpenLMLab/MOSS-RLHF
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e511509
commit f1cdbc2
Showing
17 changed files
with
2,383 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,19 @@ | ||
.DS_Store | ||
**/.DS_Store | ||
**/.DS_Store | ||
__pycache__/ | ||
outputs/ | ||
test*.sh | ||
*.csv | ||
.idea | ||
test_code*.py | ||
data_helper_deprecated.py | ||
deploy/ | ||
temp* | ||
data/* | ||
models/* | ||
tensorboard_log/* | ||
tmp/ | ||
log/ | ||
data/ | ||
debug/ | ||
inference/ |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,24 @@ | ||
# MOSS-RLHF | ||
|
||
Open source code and models are coming soon. | |
## *MOSS-RLHF & "Secrets of RLHF in Large Language Models Part II: PPO"* | ||
|
||
<p align="center" width="100%"> | ||
<img src="./assets/img/moss.png" alt="MOSS" style="width: 50%; min-width: 300px; display: block; margin: auto;"> | ||
|
||
[](./LICENSE) | ||
[](./DATA_LICENSE) | ||
[](./MODEL_LICENSE) | ||
|
||
This is the open-source code repository for the technical report: "Secrets of RLHF in Large Language Models Part II: PPO" | |
|
||
<img style="width: 90%; min-width: 500px; display: block; margin: auto; margin-bottom: 20px" alt="MOSS-RLHF" src="./assets/img/img1.jpg"> | ||
|
||
|
||
## Open-source List | ||
- Two 7B reward models, based on openChineseLlama and Llama-7B, respectively. | |
- Open source code for RL training in large language models. | ||
- ... | ||
|
||
## Getting Started | ||
|
||
TODO: to be finalised before 15 July 2023 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
compute_environment: LOCAL_MACHINE | ||
deepspeed_config: | ||
gradient_accumulation_steps: 1 | ||
gradient_clipping: 1.0 | ||
offload_optimizer_device: none | ||
offload_param_device: none | ||
zero3_init_flag: false | ||
zero_stage: 2 | ||
distributed_type: DEEPSPEED | ||
downcast_bf16: 'no' | ||
dynamo_backend: 'NO' | ||
fsdp_config: {} | ||
machine_rank: 0 | ||
main_process_ip: 10.176.98.78 | ||
main_process_port: 10532 | ||
main_training_function: main | ||
megatron_lm_config: {} | ||
mixed_precision: 'bf16' | ||
num_machines: 1 | ||
num_processes: 7 | ||
rdzv_backend: static | ||
same_network: true | ||
use_cpu: false |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import argparse | ||
|
||
def parse_args(args=None):
    """Build the MOSS-RLHF command-line parser and parse the arguments.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the original zero-argument call did.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='MOSS-RLHF @Fudan NLP Group')

    # Path
    parser.add_argument('--model_save_path', type=str, default='', help='checkpoint path, used for save model and training')
    parser.add_argument('--policy_model_path', type=str, default='', help='policy model and reference model path')
    parser.add_argument('--critic_model_path', type=str, default='', help='critic model and reward model path')
    parser.add_argument('--tokenizer_name_or_path', type=str, default='/huggingface_models/open-chinese-llama-7b', help='tokenizer name or path')
    parser.add_argument('--data_path', type=str, default='./data', help='dataset for training and validation')
    parser.add_argument('--logdir', type=str, default=None, help='path to save tensorboard logs')

    # Training
    parser.add_argument('--lr', type=float, default=5e-7, help='learning rate of policy model')
    parser.add_argument('--critic_lr', type=float, default=15e-7, help='learning rate of critic model')
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--batch_size', type=int, default=32, help='training batch size, *NOT* for sampling from env')
    parser.add_argument('--train_steps', type=int, default=5000, help='train steps')
    parser.add_argument('--warmup_steps', type=int, default=500, help='warmup steps')
    parser.add_argument('--save_per_step', type=int, default=100, help='save ckpt per steps')
    parser.add_argument('--beta1', type=float, default=0.9, help='adam')
    parser.add_argument('--beta2', type=float, default=0.95, help='adam')
    parser.add_argument('--eps', type=float, default=1e-6, help='optimizer')
    parser.add_argument('--num_workers', type=int, default=1, help='dataloader')
    parser.add_argument('--num_prefetch', type=int, default=32, help='dataloader')
    parser.add_argument('--maxlen_prompt', type=int, default=2048, help='max len for training, including model prompt and response')
    parser.add_argument('--gradient_checkpoint', action='store_true', help='deepspeed')

    # PPO in LLMs
    parser.add_argument('--num_rollouts', type=int, default=128, help='nums of samples in current replay buffer')
    parser.add_argument('--rollout_batch_size', type=int, default=32, help='batch size of sampling from env')

    parser.add_argument('--ppo_pretrain_data_path', type=str, default='', help='dataset folder path for pretrain loss of step3: rlhf')
    # The help text here used to be a copy of the *_data_path help above.
    parser.add_argument('--ppo_pretrain_data_type', type=str, default='sft', choices=['sft', 'pretrain'], help='dataset type for pretrain loss of step3: rlhf')
    parser.add_argument('--ppo_pretrain_batch_size_ratio', type=int, default=1, help='ppo batch size ratio')
    parser.add_argument('--ppo_pretrain_loss_weight', type=float, default=0., help='add pretrain loss in PPO training: ppo-ptx')
    parser.add_argument('--kl_penalty_weight', type=float, default=0.02, help='kl penalty')
    parser.add_argument('--advantage_clip', type=float, default=0.5, help='clip advantage')
    parser.add_argument('--vf_loss_weight', type=float, default=1., help='vf loss weight')
    parser.add_argument('--entropy_loss_weight', type=float, default=0., help='entropy loss weight')
    parser.add_argument('--reward_clip', type=float, default=10., help='reward clip')
    parser.add_argument('--entropy_clip', type=float, default=35., help='entropy loss clip')
    parser.add_argument('--pg_clip', type=float, default=0.2, help='pg loss clip')
    parser.add_argument('--value_clip', type=float, default=0.2, help='value clip for critic model')
    parser.add_argument('--gamma', type=float, default=1., help='GAE in PPO')
    parser.add_argument('--lam', type=float, default=0.95, help='GAE in PPO')

    # Trick and method options for PPO
    parser.add_argument('--use_reward_clip', action='store_true', help='use reward clip')
    parser.add_argument('--use_reward_scaling', action='store_true', help='use reward scaling')
    parser.add_argument('--use_reward_norm', action='store_true', help='use reward norm')
    parser.add_argument('--use_critic_loss_clip', action='store_true', help='use critic loss clip')
    parser.add_argument('--use_policy_loss_clip', action='store_true', help='use policy loss clip')
    parser.add_argument('--use_advantage_norm', action='store_true', help='use advantage norm')
    parser.add_argument('--use_advantage_clip', action='store_true', help='use advantage clip')
    parser.add_argument('--use_ppo_pretrain_loss', action='store_true', help='use ppo pretrain loss')
    parser.add_argument('--use_entropy_loss', action='store_true', help='use ppo entropy loss')

    # Sample from env
    parser.add_argument('--maxlen_res', type=int, default=128, help='max len for model response')
    parser.add_argument('--temperature', type=float, default=0.8, help='temperature')
    parser.add_argument('--repetition_penalty', type=float, default=1.1, help='repetition penalty')
    parser.add_argument('--topp', type=float, default=0.9, help='nucleus sampling')

    # args=None keeps the original behaviour of reading sys.argv[1:].
    opt = parser.parse_args(args)

    return opt
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
from typing import List, Optional, Any, Dict | ||
import math | ||
from accelerate import Accelerator | ||
import torch | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
class Metric:
    """Abstract base for one tracked statistic.

    Concrete metrics provide ``add``, ``val``, ``reset`` and ``__add__``;
    the base versions of those only raise ``NotImplementedError``.
    """

    def __init__(self):
        """Keep no state; concrete metrics define their own attributes."""

    def add(self, val):
        """Record a single observation (must be overridden)."""
        raise NotImplementedError()

    def val(self) -> float:
        """Return the current value of the metric (must be overridden)."""
        raise NotImplementedError()

    def reset(self):
        """Clear any accumulated state (must be overridden)."""
        raise NotImplementedError()

    def compute(self, val: Any):
        """Return ``val`` unchanged; a hook subclasses may override."""
        return val

    def __add__(self, other):
        """Combine with another metric (must be overridden)."""
        raise NotImplementedError()

    def __radd__(self, other):
        """Reflected addition; forwards to ``__add__`` with the same operand."""
        combined = self.__add__(other)
        return combined
|
||
|
||
class MeanMetric(Metric):
    """Mean kept as a running numerator and denominator."""

    def __init__(self, num=0, denom=0):
        """Start from an existing numerator/denominator pair (both default 0)."""
        self.denominator: int = denom
        self.numerator = num

    def add(self, val: Any):
        """Add ``compute(val)`` to the numerator and count one more sample."""
        self.numerator += self.compute(val)
        self.denominator += 1

    def many(self, vals: List[Any], denoms: Optional[List[int]] = None):
        """Add several values at once.

        Each value contributes ``compute(value)`` to the numerator and its
        matching entry of ``denoms`` to the denominator; when ``denoms`` is
        ``None`` every value counts as 1.
        """
        weights = [1] * len(vals) if denoms is None else denoms
        assert len(vals) == len(weights)

        for raw, weight in zip(vals, weights):
            self.numerator += self.compute(raw)
            self.denominator += weight

    def val(self):
        """Return numerator / denominator, or 0 when nothing was counted."""
        return 0 if self.denominator == 0 else self.numerator / self.denominator

    def reset(self):
        """Zero both the numerator and the denominator."""
        self.denominator = 0
        self.numerator = 0

    def __add__(self, other: 'MeanMetric'):
        """Return a new ``MeanMetric`` pooling both operands' sums and counts."""
        pooled_num = self.numerator + other.numerator
        pooled_denom = self.denominator + other.denominator
        return MeanMetric(pooled_num, pooled_denom)
|
||
class SumMetric(Metric):
    """Running total of every recorded value."""

    def __init__(self, sum_=0):
        """Start the total at ``sum_`` (default 0)."""
        self.sum_ = sum_

    def add(self, val):
        """Add ``compute(val)`` to the total."""
        self.sum_ += self.compute(val)

    def many(self, vals: List[Any]):
        """Add ``compute(v)`` for every ``v`` in ``vals`` to the total."""
        batch_total = sum(map(self.compute, vals))
        self.sum_ += batch_total

    def val(self):
        """Return the current total."""
        return self.sum_

    def reset(self):
        """Set the total back to 0."""
        self.sum_ = 0

    def __add__(self, other: 'SumMetric'):
        """Return a new ``SumMetric`` holding the sum of both totals."""
        combined_total = self.sum_ + other.sum_
        return SumMetric(combined_total)
|
||
|
||
class RealtimeMetric(Metric):
    """Metric that only remembers the most recently recorded value."""

    def __init__(self, val=0):
        """Start with ``val`` as the stored value (default 0)."""
        self.v = val

    def add(self, val):
        """Overwrite the stored value with ``compute(val)``."""
        latest = self.compute(val)
        self.v = latest

    def many(self, vals: List[Any]):
        """Record only the last element of ``vals``."""
        newest = vals[-1]
        self.add(newest)

    def val(self):
        """Return the stored value."""
        return self.v

    def reset(self):
        """Set the stored value back to 0."""
        self.v = 0

    def __add__(self, other):
        """Return a new ``RealtimeMetric`` with this value; ``other`` is ignored."""
        return RealtimeMetric(self.v)
|
||
class PPLMetric(MeanMetric):
    """Mean metric whose reported value is ``exp(mean)``."""

    def val(self):
        """Return ``math.exp`` of the mean, or the plain mean on overflow."""
        try:
            result = math.exp(super().val())
        except OverflowError:
            # exp() overflowed; report the un-exponentiated mean instead.
            result = super().val()
        return result

    def __add__(self, other):
        """Return a new ``PPLMetric`` pooling both operands' sums and counts."""
        pooled_num = self.numerator + other.numerator
        pooled_denom = self.denominator + other.denominator
        return PPLMetric(pooled_num, pooled_denom)
|
||
|
||
class Metrics():
    """Registry of named metrics with cross-process gathering and logging.

    A single TensorBoard writer is shared by all instances through the
    ``tb_writer`` class attribute; it is created by the first instance built
    on the main process with a non-``None`` ``opt.logdir``.
    """

    tb_writer = None

    def __init__(self, opt: Any, accelerator, mode='train'):
        """Store the configuration and lazily create the shared writer.

        Args:
            opt: Options object; only its ``logdir`` attribute is read here.
            accelerator: Object providing ``is_main_process``, ``device``,
                ``use_distributed`` and ``gather``.
            mode: ``'train'`` tags scalars as train; anything else as eval.
        """
        self.metrics = {}
        self.mode = mode
        self.opt = opt
        self.accelerator = accelerator

        if Metrics.tb_writer is None and opt.logdir is not None and self.accelerator.is_main_process:
            Metrics.tb_writer = SummaryWriter(opt.logdir)

    def create_metric(self, metric_name: str, metric_obj: 'Metric'):
        """Register ``metric_obj`` under ``metric_name``; the name must be new."""
        assert metric_name not in self.metrics
        self.metrics[metric_name] = metric_obj

    def record_metric(self, metric_name: str, val: Any):
        """Add one value to the named metric."""
        self.metrics[metric_name].add(val)

    def record_metric_many(self, metric_name: str, vals: List[Any], counts: Optional[List[int]] = None):
        """Add several values to the named metric, optionally with counts."""
        if counts is None:
            self.metrics[metric_name].many(vals)
        else:
            self.metrics[metric_name].many(vals, counts)

    def reset(self, no_reset=('global_exs',)):
        """Reset every metric whose name is not listed in ``no_reset``."""
        for name, metric in self.metrics.items():
            if name not in no_reset:
                metric.reset()

    def all_gather_metrics(self):
        """Return ``{name: float}`` with each metric combined across processes.

        When the accelerator is distributed, ``'global_exs'`` is summed over
        processes and every other metric is averaged; otherwise the local
        values are returned.
        """
        with torch.no_grad():
            metrics_tensor = {k: torch.tensor([v.val()], device=self.accelerator.device) for k, v in self.metrics.items()}

            if self.accelerator.use_distributed:
                gathered_metrics = self.accelerator.gather(metrics_tensor)
                for metric_name, gathered_tensor in gathered_metrics.items():
                    if metric_name == 'global_exs':
                        gathered_metrics[metric_name] = gathered_tensor.sum()
                    else:
                        gathered_metrics[metric_name] = gathered_tensor.float().mean()
            else:
                gathered_metrics = metrics_tensor

            gathered_metrics = {k: v.item() for k, v in gathered_metrics.items()}
        return gathered_metrics

    def write_tensorboard(self, global_step, gathered_metrics: Optional[Dict[str, float]] = None):
        """Write each metric as a ``<name>/<train|eval>`` scalar.

        Metrics are gathered here unless ``gathered_metrics`` is supplied.
        Nothing is written when no TensorBoard writer exists.
        """
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        if self.tb_writer is not None:
            suffix = 'train' if 'train' == self.mode else 'eval'
            for k, scalar in results.items():
                self.tb_writer.add_scalar(tag=f"{k}/{suffix}", scalar_value=scalar, global_step=global_step)

    def flush(self):
        """Flush the shared TensorBoard writer, if one exists."""
        if self.tb_writer is not None:
            self.tb_writer.flush()

    def display(self, global_step, data_size=None, gathered_metrics: Optional[Dict[str, float]] = None):
        """Print the metrics on the main process.

        Args:
            global_step: Step number shown in the header line.
            data_size: When given and ``'global_exs'`` is present, the header
                also shows ``global_exs / data_size`` as the epoch.
            gathered_metrics: Pre-gathered values; gathered here when ``None``.

        Returns:
            The metric dict on the main process, ``None`` on other processes.
        """
        # Gather BEFORE the main-process check: the gather is a collective
        # call, so returning early on the other ranks would leave the main
        # rank waiting on peers that never join.
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        if not self.accelerator.is_main_process:
            return None

        if data_size is not None and 'global_exs' in results:
            print(f"=========== Step: {global_step}, Epoch: {(results['global_exs'] / data_size):.2f} ===========")
        else:
            print(f'=========== Step: {global_step} ===========')

        parts = []
        for k, value in results.items():
            if isinstance(value, float):
                value = f'{value:.3e}' if k == 'lr' else f'{value:.4f}'
            parts.append(f'{k}: {value}\t')
        print(''.join(parts))
        return results
|
||
|
Empty file.
Oops, something went wrong.