Commit

moss-rlhf code init
Ablustrund committed Jul 10, 2023
1 parent e511509 commit f1cdbc2
Showing 17 changed files with 2,383 additions and 2 deletions.
19 changes: 18 additions & 1 deletion .gitignore
@@ -1,2 +1,19 @@
.DS_Store
**/.DS_Store
**/.DS_Store
__pycache__/
outputs/
test*.sh
*.csv
.idea
test_code*.py
data_helper_deprecated.py
deploy/
temp*
data/*
models/*
tensorboard_log/*
tmp/
log/
data/
debug/
inference/
407 changes: 407 additions & 0 deletions DATA_LICENSE

Large diffs are not rendered by default.

208 changes: 208 additions & 0 deletions MODEL_LICENSE

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion README.md
@@ -1,3 +1,24 @@
# MOSS-RLHF

Open-source code and models are coming soon.
## *MOSS-RLHF & "Secrets of RLHF in Large Language Models Part II: PPO"*

<p align="center" width="100%">
<img src="./assets/img/moss.png" alt="MOSS" style="width: 50%; min-width: 300px; display: block; margin: auto;">
</p>

[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-brightgreen.svg)](./LICENSE)
[![Data License](https://img.shields.io/badge/Data%20License-CC%20BY--NC%204.0-blue.svg)](./DATA_LICENSE)
[![Model License](https://img.shields.io/badge/Model%20License-GNU%20AGPL%203.0-red.svg)](./MODEL_LICENSE)

This is the open-source code repository for the technical report "Secrets of RLHF in Large Language Models Part II: PPO".

<img style="width: 90%; min-width: 500px; display: block; margin: auto; margin-bottom: 20px" alt="MOSS-RLHF" src="./assets/img/img1.jpg">


## Open-source List
- Two 7B reward models, based on openChineseLlama and Llama-7B respectively.
- Open-source code for RL training in large language models.
- ...

## Getting Started

TODO: to be finalised before July 15, 2023.
Empty file added __init__.py
Empty file.
23 changes: 23 additions & 0 deletions accelerate_config.yaml
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_process_ip: 10.176.98.78
main_process_port: 10532
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'bf16'
num_machines: 1
num_processes: 7
rdzv_backend: static
same_network: true
use_cpu: false
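
The config above pins DeepSpeed ZeRO stage 2, bf16 mixed precision, and 7 processes on a single machine. As a hedged usage sketch (not part of this commit), such a file is normally passed to `accelerate launch --config_file accelerate_config.yaml <entry_point>`; the entry-point name `train_ppo.py` below is an assumption, and the Python stub only illustrates how launched processes see the resulting setup.

```python
# Hypothetical sketch, not from this commit: a minimal entry point that would be
# started with something like
#   accelerate launch --config_file accelerate_config.yaml train_ppo.py
# The script name is an assumption; only the Accelerator calls are real API.
from accelerate import Accelerator

def main():
    # The launcher exports the DeepSpeed/bf16/process settings from the YAML,
    # and Accelerator() picks them up from the environment it prepares.
    accelerator = Accelerator()
    print(f"rank {accelerator.process_index}/{accelerator.num_processes}, "
          f"mixed precision: {accelerator.mixed_precision}")

if __name__ == '__main__':
    main()
```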
Binary file added assets/img/img1.jpg
Binary file added assets/img/moss.png
70 changes: 70 additions & 0 deletions config.py
@@ -0,0 +1,70 @@
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='MOSS-RLHF @Fudan NLP Group')

    # Path
    parser.add_argument('--model_save_path', type=str, default='', help='checkpoint path, used to save models during training')
    parser.add_argument('--policy_model_path', type=str, default='', help='policy model and reference model path')
    parser.add_argument('--critic_model_path', type=str, default='', help='critic model and reward model path')
    parser.add_argument('--tokenizer_name_or_path', type=str, default='/huggingface_models/open-chinese-llama-7b', help='tokenizer name or path')
    parser.add_argument('--data_path', type=str, default='./data', help='dataset for training and validation')
    parser.add_argument('--logdir', type=str, default=None, help='path to save tensorboard logs')

    # Training
    parser.add_argument('--lr', type=float, default=5e-7, help='learning rate of policy model')
    parser.add_argument('--critic_lr', type=float, default=15e-7, help='learning rate of critic model')
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--batch_size', type=int, default=32, help='training batch size, *NOT* for sampling from env')
    parser.add_argument('--train_steps', type=int, default=5000, help='train steps')
    parser.add_argument('--warmup_steps', type=int, default=500, help='warmup steps')
    parser.add_argument('--save_per_step', type=int, default=100, help='save ckpt every N steps')
    parser.add_argument('--beta1', type=float, default=0.9, help='adam beta1')
    parser.add_argument('--beta2', type=float, default=0.95, help='adam beta2')
    parser.add_argument('--eps', type=float, default=1e-6, help='optimizer epsilon')
    parser.add_argument('--num_workers', type=int, default=1, help='dataloader workers')
    parser.add_argument('--num_prefetch', type=int, default=32, help='dataloader prefetch')
    parser.add_argument('--maxlen_prompt', type=int, default=2048, help='max len for training, including model prompt and response')
    parser.add_argument('--gradient_checkpoint', action='store_true', help='enable gradient checkpointing (deepspeed)')

    # PPO in LLMs
    parser.add_argument('--num_rollouts', type=int, default=128, help='number of samples in the current replay buffer')
    parser.add_argument('--rollout_batch_size', type=int, default=32, help='batch size of sampling from env')

    parser.add_argument('--ppo_pretrain_data_path', type=str, default='', help='dataset folder path for the pretrain loss of step 3: rlhf')
    parser.add_argument('--ppo_pretrain_data_type', type=str, default='sft', choices=['sft', 'pretrain'], help='dataset type for the pretrain loss of step 3: rlhf')
    parser.add_argument('--ppo_pretrain_batch_size_ratio', type=int, default=1, help='ppo pretrain batch size ratio')
    parser.add_argument('--ppo_pretrain_loss_weight', type=float, default=0., help='add pretrain loss in PPO training: ppo-ptx')
    parser.add_argument('--kl_penalty_weight', type=float, default=0.02, help='kl penalty')
    parser.add_argument('--advantage_clip', type=float, default=0.5, help='clip advantage')
    parser.add_argument('--vf_loss_weight', type=float, default=1., help='vf loss weight')
    parser.add_argument('--entropy_loss_weight', type=float, default=0., help='entropy loss weight')
    parser.add_argument('--reward_clip', type=float, default=10., help='reward clip')
    parser.add_argument('--entropy_clip', type=float, default=35., help='entropy loss clip')
    parser.add_argument('--pg_clip', type=float, default=0.2, help='pg loss clip')
    parser.add_argument('--value_clip', type=float, default=0.2, help='value clip for critic model')
    parser.add_argument('--gamma', type=float, default=1., help='GAE gamma in PPO')
    parser.add_argument('--lam', type=float, default=0.95, help='GAE lambda in PPO')

    # Trick and method options for PPO
    parser.add_argument('--use_reward_clip', action='store_true', help='use reward clip')
    parser.add_argument('--use_reward_scaling', action='store_true', help='use reward scaling')
    parser.add_argument('--use_reward_norm', action='store_true', help='use reward norm')
    parser.add_argument('--use_critic_loss_clip', action='store_true', help='use critic loss clip')
    parser.add_argument('--use_policy_loss_clip', action='store_true', help='use policy loss clip')
    parser.add_argument('--use_advantage_norm', action='store_true', help='use advantage norm')
    parser.add_argument('--use_advantage_clip', action='store_true', help='use advantage clip')
    parser.add_argument('--use_ppo_pretrain_loss', action='store_true', help='use ppo pretrain loss')
    parser.add_argument('--use_entropy_loss', action='store_true', help='use ppo entropy loss')

    # Sample from env
    parser.add_argument('--maxlen_res', type=int, default=128, help='max len for model response')
    parser.add_argument('--temperature', type=float, default=0.8, help='sampling temperature')
    parser.add_argument('--repetition_penalty', type=float, default=1.1, help='repetition penalty')
    parser.add_argument('--topp', type=float, default=0.9, help='nucleus sampling top-p')

    opt = parser.parse_args()

    return opt


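A training script would obtain these options via `opt = parse_args()` and feed them into rollout collection and PPO updates. The following is a hedged, generic sketch of how the `gamma`/`lam`, `advantage_clip`, `pg_clip`, and `value_clip` options are commonly combined in a PPO step; it follows standard PPO conventions rather than code taken from this commit, and the helper names and tensor shapes are assumptions.

```python
# Generic PPO sketch (assumption, not repository code) using the option names above.
import torch

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    # Generalized Advantage Estimation over a response of length T (--gamma, --lam).
    T = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    return advantages

def ppo_losses(logprobs, old_logprobs, advantages, values, old_values, returns,
               pg_clip=0.2, value_clip=0.2, advantage_clip=0.5):
    # Optional advantage clipping (--use_advantage_clip, --advantage_clip).
    advantages = advantages.clamp(-advantage_clip, advantage_clip)

    # Clipped policy-gradient loss (--pg_clip, --use_policy_loss_clip).
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss = torch.max(-advantages * ratio,
                        -advantages * ratio.clamp(1.0 - pg_clip, 1.0 + pg_clip)).mean()

    # Clipped value loss (--value_clip, --use_critic_loss_clip).
    values_clipped = old_values + (values - old_values).clamp(-value_clip, value_clip)
    vf_loss = 0.5 * torch.max((values - returns) ** 2,
                              (values_clipped - returns) ** 2).mean()
    return pg_loss, vf_loss
```

In the same spirit, a per-token penalty of `kl_penalty_weight * (logprob_policy - logprob_ref)` is typically subtracted from the reward before GAE, and the total objective would combine `pg_loss + vf_loss_weight * vf_loss - entropy_loss_weight * entropy + ppo_pretrain_loss_weight * lm_loss`; whether this commit does exactly that is not shown here.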
185 changes: 185 additions & 0 deletions metric.py
@@ -0,0 +1,185 @@
from typing import List, Optional, Any, Dict
import math
from accelerate import Accelerator
import torch
from torch.utils.tensorboard import SummaryWriter

class Metric:
    def __init__(self):
        pass

    def add(self, val):
        raise NotImplementedError

    def val(self) -> float:
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def compute(self, val: Any):
        return val

    def __add__(self, other):
        raise NotImplementedError

    def __radd__(self, other):
        return self.__add__(other)


class MeanMetric(Metric):
    def __init__(self, num=0, denom=0):
        self.numerator = num
        self.denominator: int = denom

    def add(self, val: Any):
        self.numerator += self.compute(val)
        self.denominator += 1

    def many(self, vals: List[Any], denoms: Optional[List[int]] = None):
        if denoms is None:
            denoms = [1] * len(vals)
        assert len(vals) == len(denoms)

        for v, n in zip(vals, denoms):
            self.numerator += self.compute(v)
            self.denominator += n

    def val(self):
        if self.denominator == 0:
            return 0
        return self.numerator / self.denominator

    def reset(self):
        self.numerator = self.denominator = 0

    def __add__(self, other: 'MeanMetric'):
        return MeanMetric(self.numerator + other.numerator, self.denominator + other.denominator)

class SumMetric(Metric):
    def __init__(self, sum_=0):
        self.sum_ = sum_

    def add(self, val):
        self.sum_ += self.compute(val)

    def many(self, vals: List[Any]):
        self.sum_ += sum(self.compute(v) for v in vals)

    def val(self):
        return self.sum_

    def reset(self):
        self.sum_ = 0

    def __add__(self, other: 'SumMetric'):
        return SumMetric(self.sum_ + other.sum_)


class RealtimeMetric(Metric):
    def __init__(self, val=0):
        self.v = val

    def add(self, val):
        self.v = self.compute(val)

    def many(self, vals: List[Any]):
        self.add(vals[-1])

    def val(self):
        return self.v

    def reset(self):
        self.v = 0

    def __add__(self, other):
        return RealtimeMetric(self.v)

class PPLMetric(MeanMetric):
    def val(self):
        try:
            return math.exp(super().val())
        except OverflowError:
            return super().val()

    def __add__(self, other):
        return PPLMetric(self.numerator + other.numerator, self.denominator + other.denominator)


class Metrics():
    tb_writer = None
    def __init__(self, opt: Dict[str, Any], accelerator, mode='train'):
        self.metrics = {}
        self.mode = mode
        self.opt = opt
        self.accelerator = accelerator

        if Metrics.tb_writer is None and opt.logdir is not None and self.accelerator.is_main_process:
            Metrics.tb_writer = SummaryWriter(opt.logdir)

    def create_metric(self, metric_name: str, metric_obj: Metric):
        assert metric_name not in self.metrics
        self.metrics[metric_name] = metric_obj

    def record_metric(self, metric_name: str, val: Any):
        self.metrics[metric_name].add(val)

    def record_metric_many(self, metric_name: str, vals: List[Any], counts: Optional[List[int]] = None):
        if counts is None:
            self.metrics[metric_name].many(vals)
        else:
            self.metrics[metric_name].many(vals, counts)

    def reset(self, no_reset = ['global_exs']):
        for k, v in self.metrics.items():
            if k not in no_reset:
                v.reset()

    def all_gather_metrics(self):
        with torch.no_grad():
            metrics_tensor = {k: torch.tensor([v.val()], device=self.accelerator.device) for k, v in self.metrics.items()}

            if self.accelerator.use_distributed:
                gathered_metrics = self.accelerator.gather(metrics_tensor)
                for metric_name, gathered_tensor in gathered_metrics.items():
                    if metric_name == 'global_exs':
                        gathered_metrics[metric_name] = gathered_tensor.sum()
                    else:
                        gathered_metrics[metric_name] = gathered_tensor.float().mean()
            else:
                gathered_metrics = metrics_tensor

            gathered_metrics = {k: v.item() for k, v in gathered_metrics.items()}
        return gathered_metrics

    def write_tensorboard(self, global_step, gathered_metrics: Dict[str, float] = None):
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        if self.tb_writer is not None:
            for k, scalar in results.items():
                title = f"{k}/{'train' if 'train' == self.mode else 'eval'}"
                self.tb_writer.add_scalar(tag=title, scalar_value=scalar, global_step=global_step)

    def flush(self):
        if self.tb_writer is not None:
            self.tb_writer.flush()

    def display(self, global_step, data_size = None, gathered_metrics: Dict[str, float] = None):
        if not self.accelerator.is_main_process:
            return
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        log_str = ''
        if data_size is not None and 'global_exs' in results:
            print(f"=========== Step: {global_step}, Epoch: {(results['global_exs'] / data_size):.2f} ===========")
        else:
            print(f'=========== Step: {global_step} ===========')
        for k, value in results.items():
            if isinstance(value, float):
                if k == 'lr':
                    value = f'{value:.3e}'
                else:
                    value = f'{value:.4f}'
            log_str += f'{k}: {value}\t'
        print(log_str)
        return results


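To make the intended use of metric.py concrete, here is a hedged usage sketch: the `SimpleNamespace` option object and the toy values are assumptions for illustration, but the classes and method calls are the ones defined above.

```python
# Hedged usage sketch for metric.py (not part of the commit). `opt` only needs a
# `logdir` attribute here; everything else is hypothetical scaffolding.
from types import SimpleNamespace
from accelerate import Accelerator
from metric import Metrics, MeanMetric, SumMetric, RealtimeMetric, PPLMetric

accelerator = Accelerator()
opt = SimpleNamespace(logdir=None)  # set a path to enable TensorBoard logging

metrics = Metrics(opt, accelerator, mode='train')
metrics.create_metric('loss', MeanMetric())
metrics.create_metric('ppl', PPLMetric())
metrics.create_metric('lr', RealtimeMetric())
metrics.create_metric('global_exs', SumMetric())

for step in range(1, 3):
    metrics.record_metric('loss', 0.7)           # running mean across adds
    metrics.record_metric('ppl', 0.7)            # exp() of the mean of the added values
    metrics.record_metric('lr', 5e-7)            # keeps only the latest value
    metrics.record_metric('global_exs', 32)      # cumulative example count
    metrics.display(global_step=step)            # gathers across processes (if any) and prints
    metrics.write_tensorboard(global_step=step)  # no-op unless opt.logdir is set
    metrics.reset()                              # 'global_exs' survives reset() by default
```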
Empty file added ppo/__init__.py
Empty file.
