
Commit faa948a

Vincent Moens committed
[Algorithm] GRPO scripts
ghstack-source-id: 6768b25 Pull-Request-resolved: #2970
1 parent 6ca216e commit faa948a

File tree

4 files changed, +605 -0 lines changed


sota-implementations/llm/grpo.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""

# TODO: make sure VLLM_USE_V1=0

$ python -m pip install peft
$ python -m pip install bitsandbytes
$ python -m pip install flash_attn
$ python -m pip install datasets

"""
from __future__ import annotations

import gc
import os
from argparse import ArgumentParser

import torch

import tqdm

from grpo_utils import get_inference_model, get_ref_model, get_train_model
from torchrl import logger as torchrl_logger
from torchrl.collectors.llm import LLMCollector
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement

from torchrl.envs.llm import GSM8KEnv, KLRewardTransform

from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.record import WandbLogger

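# get_train_model, get_inference_model and get_ref_model are local helpers imported from
# grpo_utils, which is expected to sit next to this script.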
# The script expects vLLM's V0 engine; bail out if VLLM_USE_V1 is set to anything else.
if os.getenv("VLLM_USE_V1", "0") != "0":
    raise ValueError("VLLM_USE_V1=0 not set")

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--repeats", type=int, default=16)
parser.add_argument("--num_envs", type=int, default=32)
parser.add_argument("--steps_per_batch", type=int, default=64)
parser.add_argument("--optim_batch_size", type=int, default=4)
# parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-3B")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--clip_grad_norm", type=float, default=0.5)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--kl_coef", type=float, default=1e-2)

parser.add_argument("--gpu_memory_utilization", type=float, default=0.5)
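# A quick reference for how the main knobs above are used further down:
#   --repeats           number of completions sampled per prompt; this is the GRPO group
#                       size passed to both GSM8KEnv and MCAdvantage
#   --steps_per_batch   dialog turns collected per iteration, and the replay-buffer size
#   --optim_batch_size  minibatch size drawn from the buffer at each optimizer step
#   --kl_coef           weight of the KL term w.r.t. the reference model, shared by
#                       KLRewardTransform and GRPOLoss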
torch.set_default_dtype(torch.bfloat16)

torch.set_default_device("cuda:0")

def make_device_splits():
    # devices = list(range(torch.cuda.device_count()))
    # train_devices = devices[1:-1]
    # vllm_device = devices[0]
    # ref_device = devices[-1]
    devices = list(range(torch.cuda.device_count()))
    train_devices = devices[0:-2]
    vllm_devices = devices[-2:-1]
    ref_device = devices[-1]
    return train_devices, ref_device, vllm_devices

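# Resulting device layout (assuming at least three visible GPUs): the last GPU hosts the
# reference model, the second-to-last is reserved for the vLLM inference engine, and all
# remaining GPUs are used for training.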
if __name__ == "__main__":
    import ray

    ray.init()

    args = parser.parse_args()

    train_devices, ref_device, vllm_devices = make_device_splits()

    policy_training, train_tokenizer = get_train_model(args, train_devices)

    # vLLM inference model
    policy = get_inference_model(args, vllm_devices)

    # Ref model
    ref_model = get_ref_model(args, train_tokenizer, ref_device)

    # Env
    env = GSM8KEnv(
        repeats=args.repeats, tokenizer=train_tokenizer, num_envs=args.num_envs
    )
    env = env.append_transform(
        KLRewardTransform(
            actor=ref_model,
            coef=args.kl_coef,
            device=ref_device,
            add_to_reward=False,
        )
    )
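    # The KLRewardTransform estimates the KL divergence of the generated tokens w.r.t. the
    # frozen reference model. With add_to_reward=False it does not alter the task reward;
    # the penalty is stored with the data (read back as "kl_penalty" in the logging below)
    # while the actual KL regularisation is handled by GRPOLoss via kl_to_ref_coeff.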
    # replay buffer
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )
    rb.append_transform(MCAdvantage(grpo_size=args.repeats))
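    # MCAdvantage is applied when data enters the buffer: it collects the args.repeats
    # completions of the same prompt (one GRPO group) and writes a Monte-Carlo,
    # group-relative advantage under the "advantage" key consumed by GRPOLoss and the
    # logging below.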
    # Collector
    model_metadata = {
        k: (v.dtype, v.shape)
        for k, v in policy_training.model.merge_and_unload().state_dict().items()
    }
    updater = vLLMUpdater(
        master_address=None,
        master_port=None,
        model_metadata=model_metadata,
    )

    collector = LLMCollector(
        env,
        policy=policy,
        dialog_turns_per_batch=args.steps_per_batch,
        total_dialog_turns=1_000_000,
        weight_updater=updater,
    )
    updater.maybe_init_group()

    # Warmup
    torchrl_logger.info("Init weights update...")
    collector.update_policy_weights_(
        policy_training.model.merge_and_unload().state_dict(), worker_ids=[0]
    )
    torchrl_logger.info("done\n")
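    # Weight synchronisation: vLLMUpdater is given the dtype and shape of every parameter
    # (model_metadata) so it can set up the transfer, and merge_and_unload() (a PEFT call
    # that folds the trained adapters into the base model) is used so that a plain
    # state_dict can be pushed to the vLLM worker. The warmup call above performs one
    # initial sync so that collection starts from the current training weights.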
    # Loss module
    loss_fn = GRPOLoss(actor_network=policy_training, kl_to_ref_coeff=args.kl_coef)

    if args.compile:
        loss_fn = torch.compile(loss_fn)

    # TODO: foreach=False to avoid "Tensors of the same index must be on the same device" error due to "auto" device map
    optim = torch.optim.AdamW(policy_training.model.parameters(), lr=args.lr)
    logger = WandbLogger(exp_name=args.model_name)
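    # GRPOLoss combines a clipped, importance-weighted objective on the stored advantages
    # with a KL-to-reference penalty scaled by kl_to_ref_coeff; besides the loss terms it
    # returns the diagnostics (ESS, clip_fraction, kl_approx, ...) logged in the inner
    # training loop below.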
    for i, trajs in enumerate(collector):
        torchrl_logger.info(f"Collected batch {i}: {trajs=}")

        # rb.empty()
        trajs = trajs.reshape(-1)
        rb.extend(trajs)

        # logging
        reward = torch.cat(rb[:].get(("next", "reward"), as_list=True)).mean()
        advantage = torch.cat(rb[:].get("advantage", as_list=True)).mean()
        kl_penalty = torch.cat(rb[:].get(("next", "kl_penalty"), as_list=True)).mean()
        seq_length = []
        for t in rb[:].get("tokens_response", as_list=True):
            seq_length.append(t.numel())
        seq_length = torch.tensor(seq_length, dtype=torch.float).mean()

        if not reward:
            # no use in training a model without reward
            torchrl_logger.info("no reward - skipping")
            torch.cuda.empty_cache()  # TODO: Test if this is needed
            continue
        logger.log_scalar("reward", reward)
        logger.log_scalar("advantage", advantage)
        logger.log_scalar("kl_penalty", kl_penalty)
        logger.log_scalar("seq_length", seq_length)

        torchrl_logger.info(f"reward: {reward: 4.4f}")
        for i in range(args.epochs):
            torchrl_logger.info(f"epoch: {i}")
            pbar = tqdm.tqdm(total=len(rb) // args.optim_batch_size)
            for batch in rb:
                pbar.update(1)
                optim.zero_grad()
                batch = batch.to(train_devices[0])
                loss = loss_fn(batch)
                loss_val = loss.mean(reduce=True)
                loss_val.backward()
                gn = torch.nn.utils.clip_grad_norm_(
                    policy_training.model.parameters(), args.clip_grad_norm
                )
                optim.step()

                logger.log_scalar("ESS", loss.ESS)
                logger.log_scalar("loss_objective", loss.loss_objective)
                logger.log_scalar("clip_fraction", loss.clip_fraction)
                logger.log_scalar("kl_approx", loss.kl_approx)
                logger.log_scalar("grad_norm", gn)
                logger.log_scalar("entropy", loss.loss_entropy.mean())
                logger.log_scalar("kl_to_ref", loss.kl_to_ref.mean())
                logger.log_scalar("loss_kl_to_ref", loss.loss_kl_to_ref.mean())

        # scaler.update()

        gc.collect()
        torch.cuda.empty_cache()

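        # Push the updated weights back to the vLLM inference worker before the next
        # collection round; as in the warmup above, the adapters are first merged into
        # the base model so a plain state_dict can be transferred.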
torchrl_logger.info("Updating weights...")
208+
collector.update_policy_weights_(
209+
policy_weights=policy_training.model.merge_and_unload().state_dict(),
210+
worker_ids=[0],
211+
)
212+
gc.collect()
213+
torch.cuda.empty_cache()
