 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
+from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
@@ -15,17 +16,21 @@ def main(cfg: "DictConfig"): # noqa: F821
     import torch.optim
     import tqdm

-    from tensordict import TensorDict
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models

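+    # Resolve the device: an explicit cfg.loss.device wins; otherwise fall back
+    # to cuda:0 when available, else CPU. Normalizing to torch.device lets later
+    # code branch on device.type (e.g. the optimizer's capturable flag below).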
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    device = cfg.loss.device
+    if not device:
+        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+    else:
+        device = torch.device(device)

     # Correct for frame_skip
     frame_skip = 4
@@ -35,28 +40,17 @@ def main(cfg: "DictConfig"): # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
+    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
     actor, critic, critic_head = (
         actor.to(device),
         critic.to(device),
         critic_head.to(device),
     )

-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
-        policy=actor,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-    )
-
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
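+        # LazyTensorStorage pre-allocates plain tensors on `device` (the memmap
+        # variant pins the buffer to CPU/disk), so sampled minibatches need no
+        # host-to-device copy in the update loop.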
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, device=device),
         sampler=sampler,
         batch_size=mini_batch_size,
     )
@@ -67,6 +61,7 @@ def main(cfg: "DictConfig"): # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=True,
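+        # Assumption on intent: the vectorized GAE path is disabled when
+        # compiling, since torch.compile can fuse the time loop itself, while
+        # eager mode benefits from the vectorized implementation.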
+        vectorized=not cfg.loss.compile,
     )
71
66
loss_module = A2CLoss (
72
67
actor_network = actor ,
@@ -83,9 +78,10 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
-        lr=cfg.optim.lr,
+        lr=torch.tensor(cfg.optim.lr, device=device),
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
     )
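+    # NB: passing the learning rate as an on-device tensor together with
+    # capturable=True lets the LR be annealed in-place later via
+    # group["lr"].copy_(...), which stays valid under CUDA-graph capture.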

     # Create logger
@@ -115,16 +111,71 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
     test_env.eval()

+    # Update function: one A2C optimization step over a minibatch
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
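+    # Returning only detached loss values (plus the grad norm) keeps the
+    # function's outputs small and graph-detached, which is friendlier to
+    # torch.compile and CUDA-graph capture below.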
+
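+    # A plausible reading of the mode selection: "reduce-overhead" makes
+    # torch.compile capture CUDA graphs internally, so when CudaGraphModule is
+    # requested explicitly (cfg.loss.cudagraphs) the mode is left as None to
+    # avoid capturing twice.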
+    if cfg.loss.compile:
+        compile_mode = cfg.loss.compile_mode
+        if compile_mode in ("", None):
+            if cfg.loss.cudagraphs:
+                compile_mode = None
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        actor = torch.compile(actor, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.loss.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        actor = CudaGraphModule(actor)
+        adv_module = CudaGraphModule(adv_module)
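+    # warmup=5 runs the wrapped update eagerly a few times before capture so
+    # lazy initializations settle; in_keys=[]/out_keys=[] appear to mark update
+    # as a plain tensordict-in/tensordict-out callable rather than a module
+    # with declared keys.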
+
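+    # The collector is created only after the policy has been wrapped, so
+    # rollouts run through the compiled / CUDA-graphed actor.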
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+    )
+
     # Main loop
     collected_frames = 0
     num_network_updates = 0
     start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
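+    # Keep a plain-float copy of the base LR for the annealing arithmetic; the
+    # optimizer's own lr is a tensor (see the capturable Adam above).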
+    lr = cfg.optim.lr

     sampling_start = time.time()
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
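+            # Tell the CUDA-graph runtime that a new step begins, so tensors
+            # produced by the previous replay may be overwritten safely.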
+            torch.compiler.cudagraph_mark_step_begin()
+            data = next(c_iter)

         log_info = {}
         sampling_time = time.time() - sampling_start
@@ -144,59 +195,53 @@ def main(cfg: "DictConfig"): # noqa: F821
             }
         )

-        losses = TensorDict({}, batch_size=[num_mini_batches])
+        losses = []
         training_start = time.time()

         # Compute GAE
-        with torch.no_grad():
+        with torch.no_grad(), timeit("advantage"):
             data = adv_module(data)
         data_reshape = data.reshape(-1)

         # Update the data buffer
-        data_buffer.extend(data_reshape)
-
-        for k, batch in enumerate(data_buffer):
-
-            # Get a data batch
-            batch = batch.to(device)
-
-            # Linearly decrease the learning rate and clip epsilon
-            alpha = 1.0
-            if cfg.optim.anneal_lr:
-                alpha = 1 - (num_network_updates / total_network_updates)
-                for group in optim.param_groups:
-                    group["lr"] = cfg.optim.lr * alpha
-            num_network_updates += 1
-
-            # Forward pass A2C loss
-            loss = loss_module(batch)
-            losses[k] = loss.select(
-                "loss_critic", "loss_entropy", "loss_objective"
-            ).detach()
-            loss_sum = (
-                loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-            )
-
-            # Backward pass
-            loss_sum.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
-            )
-
-            # Update the networks
-            optim.step()
-            optim.zero_grad()
-
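+        # Presumably emptying the on-device buffer just resets its write cursor,
+        # so extend() refills the pre-allocated storage without new allocations.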
+        with timeit("emptying"):
+            data_buffer.empty()
+        with timeit("extending"):
+            data_buffer.extend(data_reshape)
+
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)
+
+                num_network_updates += 1
+
+                with timeit("optim - update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch)
+                losses.append(loss)
+
+        if i % 200 == 0:
+            timeit.print()
+            timeit.erase()
         # Get training losses
         training_time = time.time() - training_start
-        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+        losses = torch.stack(losses).float().mean()
+
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg.optim.lr,
+                "train/lr": lr * alpha,
                 "train/sampling_time": sampling_time,
                 "train/training_time": training_time,
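+                # Fold the accumulated timeit sections into the log payload.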
+                **timeit.todict(prefix="time"),
             }
         )
@@ -223,7 +268,6 @@ def main(cfg: "DictConfig"): # noqa: F821
         for key, value in log_info.items():
             logger.log_scalar(key, value, collected_frames)

-        collector.update_policy_weights_()
         sampling_start = time.time()

     collector.shutdown()