
Commit b80312c

Author: Vincent Moens

[Algorithm] GRPO scripts

ghstack-source-id: 22a66ef
Pull-Request-resolved: #2970

1 parent 6e4b550 commit b80312c

File tree

4 files changed (+634 -0 lines)


sota-implementations/llm/grpo.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""

# TODO: make sure VLLM_USE_V1=0

$ python -m pip install peft
$ python -m pip install bitsandbytes
$ python -m pip install flash_attn
$ python -m pip install datasets

"""
from __future__ import annotations

import gc
import os
from argparse import ArgumentParser

import torch
import tqdm

from grpo_utils import get_inference_model, get_ref_model, get_train_model
from torchrl import logger as torchrl_logger
from torchrl.collectors.llm import LLMCollector
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
from torchrl.envs.llm import GSM8KEnv, KLRewardTransform
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.record import WandbLogger

# The script requires the vLLM V1 engine to be disabled (see the TODO in the docstring above).
if os.getenv("VLLM_USE_V1", "0") != "0":
    raise ValueError("VLLM_USE_V1=0 not set")

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--repeats", type=int, default=16)
parser.add_argument("--num_envs", type=int, default=32)
parser.add_argument("--steps_per_batch", type=int, default=64)
parser.add_argument("--optim_batch_size", type=int, default=4)
# parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-3B")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--clip_grad_norm", type=float, default=0.5)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--kl_coef", type=float, default=1e-2)
parser.add_argument("--gpu_memory_utilization", type=float, default=0.5)

torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda:0")


def make_device_splits():
    # Split the visible CUDA devices: all but the last two for training, the second-to-last
    # for vLLM inference, the last one for the reference model (needs at least 3 devices).
    # devices = list(range(torch.cuda.device_count()))
    # train_devices = devices[1:-1]
    # vllm_device = devices[0]
    # ref_device = devices[-1]
    devices = list(range(torch.cuda.device_count()))
    train_devices = devices[0:-2]
    vllm_devices = devices[-2:-1]
    ref_device = devices[-1]
    return train_devices, ref_device, vllm_devices

if __name__ == "__main__":
    import ray

    ray.init()

    args = parser.parse_args()

    train_devices, ref_device, vllm_devices = make_device_splits()

    # Trainable policy (PEFT-wrapped) and its tokenizer
    policy_training, train_tokenizer = get_train_model(args, train_devices)

    # vLLM inference policy
    policy = get_inference_model(args, vllm_devices)

    # Reference model used for the KL penalty
    ref_model = get_ref_model(args, train_tokenizer, ref_device)

    # Env: GSM8K prompts, each repeated `repeats` times so that GRPO can form groups
    env = GSM8KEnv(
        repeats=args.repeats, tokenizer=train_tokenizer, num_envs=args.num_envs
    )
    env = env.append_transform(
        KLRewardTransform(
            actor=ref_model,
            coef=args.kl_coef,
            device=ref_device,
            add_to_reward=False,
        )
    )

    # Replay buffer; the MCAdvantage transform attaches a Monte-Carlo, per-prompt-group advantage
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )
    rb.append_transform(MCAdvantage(grpo_size=args.repeats))
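    # How the grouping works: with grpo_size=args.repeats, MCAdvantage gathers the `repeats`
    # completions generated for the same prompt and scores each one by how its return compares to
    # the rest of its group, e.g. A_i = (r_i - mean(r_group)) / std(r_group) in the standard GRPO
    # formulation. (Illustrative summary; see the MCAdvantage documentation for the exact
    # computation used here.)
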
    # Weight updater: model_metadata records the dtype/shape of every tensor in the merged
    # (adapters folded into the base weights) state dict that will be broadcast to vLLM
    model_metadata = {
        k: (v.dtype, v.shape)
        for k, v in policy_training.model.merge_and_unload().state_dict().items()
    }
    updater = vLLMUpdater(
        master_address=None,
        master_port=None,
        model_metadata=model_metadata,
    )

    # Collector: runs the vLLM policy in the environment and yields batches of dialog turns
    collector = LLMCollector(
        env,
        policy=policy,
        dialog_turns_per_batch=args.steps_per_batch,
        total_dialog_turns=1_000_000,
        weight_updater=updater,
    )
    updater.maybe_init_group()

    # Warmup: push the initial merged weights to the vLLM worker
    torchrl_logger.info("Init weights update...")
    collector.update_policy_weights_(
        policy_training.model.merge_and_unload().state_dict(), worker_ids=[0]
    )
    torchrl_logger.info("done\n")

    # Loss module
    loss_fn = GRPOLoss(actor_network=policy_training, kl_to_ref_coeff=args.kl_coef)

    if args.compile:
        loss_fn = torch.compile(loss_fn)

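    # In the usual GRPO formulation this is a PPO-style objective: the token-level importance
    # ratio between the current policy and the data-collecting policy is clipped, and a
    # KL(policy || ref) penalty scaled by kl_to_ref_coeff is added; the ESS, clip_fraction and
    # kl_approx values logged in the training loop are diagnostics of that surrogate.
    # (Illustrative summary; refer to torchrl.objectives.llm.grpo for the exact definition.)
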
    # TODO: foreach=False to avoid "Tensors of the same index must be on the same device" error due to "auto" device map
    optim = torch.optim.AdamW(policy_training.model.parameters(), lr=args.lr)
    logger = WandbLogger(exp_name=args.model_name)

    for i, trajs in enumerate(collector):
        torchrl_logger.info(f"Collected batch {i}: {trajs=}")

        # rb.empty()
        trajs = trajs.reshape(-1)
        rb.extend(trajs)

        # Logging: mean reward, advantage, KL penalty and response length over the buffer
        reward = torch.cat(rb[:].get(("next", "reward"), as_list=True)).mean()
        advantage = torch.cat(rb[:].get("advantage", as_list=True)).mean()
        kl_penalty = torch.cat(rb[:].get(("next", "kl_penalty"), as_list=True)).mean()
        seq_length = []
        for t in rb[:].get("tokens_response", as_list=True):
            seq_length.append(t.numel())
        seq_length = torch.tensor(seq_length, dtype=torch.float).mean()

        if not reward:
            # No use in training a model without reward
            torchrl_logger.info("no reward - skipping")
            torch.cuda.empty_cache()  # TODO: Test if this is needed
            continue
        logger.log_scalar("reward", reward)
        logger.log_scalar("advantage", advantage)
        logger.log_scalar("kl_penalty", kl_penalty)
        logger.log_scalar("seq_length", seq_length)

torchrl_logger.info(f"reward: {reward: 4.4f}")
177+
for i in range(args.epochs):
178+
torchrl_logger.info(f"epoch: {i}")
179+
pbar = tqdm.tqdm(total=len(rb) // args.optim_batch_size)
180+
for batch in rb:
181+
pbar.update(1)
182+
optim.zero_grad()
183+
batch = batch.to(train_devices[0])
184+
loss = loss_fn(batch)
185+
loss_val = loss.mean(reduce=True)
186+
loss_val.backward()
187+
gn = torch.nn.utils.clip_grad_norm_(
188+
policy_training.model.parameters(), args.clip_grad_norm
189+
)
190+
optim.step()
191+
192+
logger.log_scalar("ESS", loss.ESS)
193+
logger.log_scalar("loss_objective", loss.loss_objective)
194+
logger.log_scalar("clip_fraction", loss.clip_fraction)
195+
logger.log_scalar("kl_approx", loss.kl_approx)
196+
logger.log_scalar("grad_norm", gn)
197+
logger.log_scalar("entropy", loss.loss_entropy.mean())
198+
logger.log_scalar("kl_to_ref", loss.kl_to_ref.mean())
199+
logger.log_scalar("loss_kl_to_ref", loss.loss_kl_to_ref.mean())
200+
201+
# scaler.update()
202+
203+
        gc.collect()
        torch.cuda.empty_cache()

        # Sync the updated merged weights to the vLLM inference worker before the next collection
        torchrl_logger.info("Updating weights...")
        collector.update_policy_weights_(
            policy_weights=policy_training.model.merge_and_unload().state_dict(),
            worker_ids=[0],
        )
        gc.collect()
        torch.cuda.empty_cache()
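
A minimal launch sketch, assuming the pip dependencies from the module docstring are installed and at least three CUDA devices are visible (make_device_splits reserves the second-to-last GPU for vLLM and the last one for the reference model); the flags mirror the argparse defaults above:

$ python -m pip install peft bitsandbytes flash_attn datasets
$ VLLM_USE_V1=0 python sota-implementations/llm/grpo.py \
    --model_name Qwen/Qwen2.5-3B --repeats 16 --num_envs 32 \
    --steps_per_batch 64 --optim_batch_size 4 --lr 1e-5 --kl_coef 1e-2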
