
Commit c9b036a

Authored and committed by Vincent Moens
[Feature] A2C compatibility with compile
ghstack-source-id: 04176b8 Pull Request resolved: #2464
1 parent: b116151

14 files changed (+317, -189 lines)

benchmarks/test_objectives_benchmarks.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@
 )  # Anything from 2.5, incl. nightlies, allows for fullgraph


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="module", autouse=True)
 def set_default_device():
     cur_device = torch.get_default_device()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
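Note: with autouse=True, the fixture now applies to every test in the module without being requested by name. A standalone sketch of what the complete fixture plausibly looks like (the yield/restore half is an assumption, not part of this hunk):

import pytest
import torch

@pytest.fixture(scope="module", autouse=True)
def set_default_device():
    # Remember the current default device, switch to CUDA when available,
    # and restore the original default once the module's tests are done.
    cur_device = torch.get_default_device()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)
    yield
    torch.set_default_device(cur_device)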

sota-implementations/a2c/a2c_atari.py

Lines changed: 101 additions & 57 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
+from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder

@@ -15,17 +16,21 @@ def main(cfg: "DictConfig"): # noqa: F821
     import torch.optim
     import tqdm

-    from tensordict import TensorDict
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models

-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    device = cfg.loss.device
+    if not device:
+        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+    else:
+        device = torch.device(device)

     # Correct for frame_skip
     frame_skip = 4
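Note: the device is now taken from the config first, with auto-detection as a fallback. A standalone sketch of the same resolution logic (the function name is illustrative):

import torch

def resolve_device(cfg_device):
    # An empty/None config value falls back to auto-detection;
    # otherwise the configured string is used verbatim.
    if not cfg_device:
        return torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    return torch.device(cfg_device)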
@@ -35,28 +40,17 @@ def main(cfg: "DictConfig"): # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
+    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
     actor, critic, critic_head = (
         actor.to(device),
         critic.to(device),
         critic_head.to(device),
     )

-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
-        policy=actor,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-    )
-
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, device=device),
         sampler=sampler,
         batch_size=mini_batch_size,
     )
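Note: swapping LazyMemmapStorage for LazyTensorStorage keeps the buffer as plain tensors on the training device rather than as memory-mapped files on disk, which removes host/device copies from the inner loop. A minimal sketch of the buffer in isolation (sizes and keys are illustrative):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

device = "cuda:0" if torch.cuda.is_available() else "cpu"
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(256, device=device),  # on-device tensor storage
    sampler=SamplerWithoutReplacement(),
    batch_size=32,
)
rb.extend(TensorDict({"obs": torch.randn(256, 4)}, batch_size=[256]).to(device))
batch = rb.sample()  # 32 elements, already on the storage device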
@@ -67,6 +61,7 @@ def main(cfg: "DictConfig"): # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=True,
+        vectorized=not cfg.loss.compile,
     )
     loss_module = A2CLoss(
         actor_network=actor,
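Note: vectorized=not cfg.loss.compile disables GAE's vectorized closed-form path when compiling, since the explicit time-loop recurrence is friendlier to torch.compile. A standalone sketch of that recurrence (not torchrl's implementation):

import torch

def gae(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
    # Backward recurrence: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    adv = torch.zeros_like(reward)
    carry = torch.zeros_like(reward[0])
    for t in range(reward.shape[0] - 1, -1, -1):
        not_done = 1.0 - done[t].float()
        delta = reward[t] + gamma * next_value[t] * not_done - value[t]
        carry = delta + gamma * lmbda * not_done * carry
        adv[t] = carry
    return adv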
@@ -83,9 +78,10 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
-        lr=cfg.optim.lr,
+        lr=torch.tensor(cfg.optim.lr, device=device),
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
     )

     # Create logger
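Note: a tensor-valued lr together with capturable=True keeps the optimizer state on-device, so the learning rate can later be annealed in place (group["lr"].copy_(...)) with no Python-float rebinding, no graph break under torch.compile, and no invalidation of a captured CUDA graph. A standalone sketch (tensor lr support assumes a recent PyTorch):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
param = torch.nn.Parameter(torch.randn(4, device=device))
optim = torch.optim.Adam(
    [param],
    lr=torch.tensor(1e-3, device=device),  # lr as a tensor, not a float
    capturable=device.type == "cuda",      # keep state on-device for capture
)
(param ** 2).sum().backward()
optim.step()
for group in optim.param_groups:
    group["lr"].copy_(group["lr"] * 0.99)  # in-place anneal, no rebinding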
@@ -115,16 +111,71 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
     test_env.eval()

+    # update function
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
+
+    if cfg.loss.compile:
+        compile_mode = cfg.loss.compile_mode
+        if compile_mode in ("", None):
+            if cfg.loss.cudagraphs:
+                compile_mode = None
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        actor = torch.compile(actor, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.loss.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        actor = CudaGraphModule(actor)
+        adv_module = CudaGraphModule(adv_module)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+    )
+
     # Main loop
     collected_frames = 0
     num_network_updates = 0
     start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
+    lr = cfg.optim.lr

     sampling_start = time.time()
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            torch.compiler.cudagraph_mark_step_begin()
+            data = next(c_iter)

         log_info = {}
         sampling_time = time.time() - sampling_start
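Note: torch.compiler.cudagraph_mark_step_begin() tells the CUDA-graphs backend used by mode="reduce-overhead" that a new iteration is starting, so output buffers from the previous graph replay may be reused; CudaGraphModule (from tensordict.nn) is the alternative capture path taken when cfg.loss.cudagraphs is set. A minimal standalone sketch of the marker (toy function, CUDA only):

import torch

def step(x):
    return x * 2 + 1

if torch.cuda.is_available():
    compiled = torch.compile(step, mode="reduce-overhead")  # CUDA-graph backed
    x = torch.randn(8, device="cuda")
    for _ in range(5):
        # Mark the iteration boundary so the previous replay's outputs
        # can be safely overwritten by the next one.
        torch.compiler.cudagraph_mark_step_begin()
        y = compiled(x)
    print(y.sum().item())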
@@ -144,59 +195,53 @@ def main(cfg: "DictConfig"): # noqa: F821
             }
         )

-        losses = TensorDict({}, batch_size=[num_mini_batches])
+        losses = []
         training_start = time.time()

         # Compute GAE
-        with torch.no_grad():
+        with torch.no_grad(), timeit("advantage"):
             data = adv_module(data)
         data_reshape = data.reshape(-1)

         # Update the data buffer
-        data_buffer.extend(data_reshape)
-
-        for k, batch in enumerate(data_buffer):
-
-            # Get a data batch
-            batch = batch.to(device)
-
-            # Linearly decrease the learning rate and clip epsilon
-            alpha = 1.0
-            if cfg.optim.anneal_lr:
-                alpha = 1 - (num_network_updates / total_network_updates)
-                for group in optim.param_groups:
-                    group["lr"] = cfg.optim.lr * alpha
-            num_network_updates += 1
-
-            # Forward pass A2C loss
-            loss = loss_module(batch)
-            losses[k] = loss.select(
-                "loss_critic", "loss_entropy", "loss_objective"
-            ).detach()
-            loss_sum = (
-                loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-            )
-
-            # Backward pass
-            loss_sum.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
-            )
-
-            # Update the networks
-            optim.step()
-            optim.zero_grad()
-
+        with timeit("emptying"):
+            data_buffer.empty()
+        with timeit("extending"):
+            data_buffer.extend(data_reshape)
+
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)
+
+                num_network_updates += 1
+
+                with timeit("optim - update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch)
+                losses.append(loss)
+
+        if i % 200 == 0:
+            timeit.print()
+            timeit.erase()
         # Get training losses
         training_time = time.time() - training_start
-        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+        losses = torch.stack(losses).float().mean()
+
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg.optim.lr,
+                "train/lr": lr * alpha,
                 "train/sampling_time": sampling_time,
                 "train/training_time": training_time,
+                **timeit.todict(prefix="time"),
             }
         )
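Note: timeit here is torchrl's lightweight profiling helper (from torchrl._utils import timeit), used both as a context manager and as an aggregate reporter. A minimal sketch of the usage pattern above (exact report format may vary across versions):

import time
from torchrl._utils import timeit

for _ in range(3):
    with timeit("collecting"):  # accumulate wall-clock time under a key
        time.sleep(0.01)

timeit.print()                        # log the accumulated timings
stats = timeit.todict(prefix="time")  # e.g. {"time/collecting": ...}
timeit.erase()                        # reset the counters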

@@ -223,7 +268,6 @@ def main(cfg: "DictConfig"): # noqa: F821
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)

-        collector.update_policy_weights_()
         sampling_start = time.time()

     collector.shutdown()
