 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
 import os
+from copy import deepcopy
 from mpi4py import MPI as mpi
 import torch
 from absl import flags
@@ -41,7 +42,7 @@ def main(args):
     if rank == 0:
         timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
         log_id = make_log_id_from_timestamp(args.tag, args.mode_name, args.agent,
-                                            args.vision_network + args.network_body,
+                                            args.network_vision + args.network_body,
                                             timestamp)
         log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)
         os.makedirs(log_id_dir)
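
Only rank 0 creates the log directory here; in the next hunk the workers rebuild the same log_id from the shared timestamp, which presumably travels over MPI in the lines elided between the hunks. A minimal runnable sketch of that shared-timestamp pattern with mpi4py (the log-id format and path below are illustrative, not adept's API):

    # sketch: rank 0 chooses a timestamp, every rank derives the same log id
    import os
    from datetime import datetime
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') if rank == 0 else None
    timestamp = comm.bcast(timestamp, root=0)  # lowercase bcast pickles small objects

    log_id = 'impala_{}'.format(timestamp)
    if rank == 0:
        os.makedirs(os.path.join('/tmp/logs', log_id), exist_ok=True)

Run under e.g. mpiexec -n 3; every rank ends up with the identical log_id, but only the host touches the filesystem.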
@@ -53,34 +54,54 @@ def main(args):
 
     if rank != 0:
         log_id = make_log_id_from_timestamp(args.tag, args.mode_name, args.agent,
-                                            args.vision_network + args.network_body,
+                                            args.network_vision + args.network_body,
                                             timestamp)
         log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)
 
     comm.Barrier()
 
     # construct env
-    seed = args.seed if rank == 0 else args.seed + (args.nb_env * (rank - 1))  # unique seed per process
-    env = make_env(args, seed)
+    # unique seed per process
+    seed = args.seed if rank == 0 else args.seed + args.nb_env * (rank - 1)
+    # don't make a ton of envs if host
+    if rank == 0:
+        env_args = deepcopy(args)
+        env_args.nb_env = 1
+        env = make_env(env_args, seed)
+    else:
+        env = make_env(args, seed)
 
     # construct network
     torch.manual_seed(args.seed)
     network_head_shapes = get_head_shapes(env.action_space, args.agent)
     network = make_network(env.observation_space, network_head_shapes, args)
 
-    # sync network params
-    if rank == 0:
-        for v in network.parameters():
-            comm.Bcast(v.detach().cpu().numpy(), root=0)
-        print('Root variables synced')
+    # possibly load network
+    initial_step_count = 0
+    if args.load_network:
+        network.load_state_dict(
+            torch.load(
+                args.load_network, map_location=lambda storage, loc: storage
+            )
+        )
+        # get step count from network file
+        epoch_dir = os.path.split(args.load_network)[0]
+        initial_step_count = int(os.path.split(epoch_dir)[-1])
+        print('Reloaded network from {}'.format(args.load_network))
+    # only sync network params if not loading
     else:
-        # can just use the numpy buffers
-        variables = [v.detach().cpu().numpy() for v in network.parameters()]
-        for v in variables:
-            comm.Bcast(v, root=0)
-        for shared_v, model_v in zip(variables, network.parameters()):
-            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
-        print('{} variables synced'.format(rank))
+        if rank == 0:
+            for v in network.parameters():
+                comm.Bcast(v.detach().cpu().numpy(), root=0)
+            print('Root variables synced')
+        else:
+            # can just use the numpy buffers
+            variables = [v.detach().cpu().numpy() for v in network.parameters()]
+            for v in variables:
+                comm.Bcast(v, root=0)
+            for shared_v, model_v in zip(variables, network.parameters()):
+                model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
+            print('{} variables synced'.format(rank))
 
     # construct agent
     # host is always the first gpu, workers are distributed evenly across the rest
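
For reference, a self-contained sketch of the broadcast sync the new else-branch performs: rank 0 sends its parameter buffers and every worker copies them into its own network. The tiny Linear network is a stand-in for the real model; the MPI calls mirror the diff.

    # sketch: broadcast torch parameters from rank 0 to all workers
    import torch
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    torch.manual_seed(rank)  # deliberately diverge so the sync is observable
    network = torch.nn.Linear(4, 2)

    if rank == 0:
        for v in network.parameters():
            comm.Bcast(v.detach().cpu().numpy(), root=0)
    else:
        buffers = [v.detach().cpu().numpy() for v in network.parameters()]
        for b in buffers:
            comm.Bcast(b, root=0)  # filled in place with rank 0's values
        for b, v in zip(buffers, network.parameters()):
            v.data.copy_(torch.from_numpy(b))

    print('rank {} first weight row: {}'.format(rank, next(network.parameters())[0]))

Uppercase Bcast is mpi4py's buffer-based variant, so the numpy views transfer without pickling; the explicit copy_ back into the parameters is what matters when the network actually lives on a GPU.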
@@ -120,7 +141,7 @@ def main(args):
             profiler.stop()
             print(profiler.output_text(unicode=True, color=True))
         else:
-            container.run()
+            container.run(initial_step_count)
         env.close()
     # host
     else:
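
The worker now passes initial_step_count into the run loop so step bookkeeping resumes from the checkpoint instead of restarting at zero. The container's internals are not shown in this diff; a purely hypothetical sketch of a loop honoring such an offset:

    # hypothetical sketch: a worker loop that resumes its step counter
    class WorkerContainer:
        def __init__(self, nb_env, max_steps):
            self.nb_env = nb_env
            self.max_steps = max_steps

        def run(self, initial_step_count=0):
            step_count = initial_step_count  # resume, don't restart at zero
            while step_count < self.max_steps:
                step_count += self.nb_env  # one batched env step per iteration
            return step_count

    WorkerContainer(nb_env=32, max_steps=1000).run(initial_step_count=640)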
@@ -136,6 +157,12 @@ def main(args):
         # Construct the optimizer
         def make_optimizer(params):
             opt = torch.optim.RMSprop(params, lr=args.learning_rate, eps=1e-5, alpha=0.99)
+            if args.load_optimizer:
+                opt.load_state_dict(
+                    torch.load(
+                        args.load_optimizer, map_location=lambda storage, loc: storage
+                    )
+                )
             return opt
 
         container = ImpalaHost(agent, comm, make_optimizer, summary_writer, args.summary_frequency, saver,
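
The map_location=lambda storage, loc: storage argument makes torch.load deserialize the checkpoint onto CPU regardless of the device it was saved from. A self-contained sketch of the save/restore round trip (the path and hyperparameters are illustrative):

    # sketch: checkpoint and restore an RMSprop optimizer's state
    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.RMSprop(model.parameters(), lr=7e-4, eps=1e-5, alpha=0.99)

    # take one step so the optimizer has per-parameter state worth saving
    model(torch.randn(1, 4)).sum().backward()
    opt.step()
    torch.save(opt.state_dict(), '/tmp/optimizer.pth')

    # later: rebuild the optimizer, then restore its state on CPU
    opt2 = torch.optim.RMSprop(model.parameters(), lr=7e-4, eps=1e-5, alpha=0.99)
    opt2.load_state_dict(
        torch.load('/tmp/optimizer.pth', map_location=lambda storage, loc: storage)
    )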
@@ -169,59 +196,41 @@ def make_optimizer(params):
     import argparse
     from adept.utils.script_helpers import add_base_args, parse_bool
 
-    parser = argparse.ArgumentParser(description='AdeptRL IMPALA Mode')
-    parser = add_base_args(parser)
-    parser.add_argument('--gpu-id', type=int, nargs='+', default=[0],
-                        help='Which GPU to use for training. The host will always be the first gpu, workers are distributed evenly across the rest (default: [0])')
-    parser.add_argument(
-        '-vn', '--vision-network', default='Nature',
-        help='name of preset network (default: Nature)'
-    )
-    parser.add_argument(
-        '-dn', '--discrete-network', default='Identity',
-    )
-    parser.add_argument(
-        '-nb', '--network-body', default='LSTM',
-    )
-    parser.add_argument(
-        '--agent', default='ActorCriticVtrace',
-        help='name of preset agent (default: ActorCriticVtrace)'
-    )
-    parser.add_argument(
-        '--profile', type=parse_bool, nargs='?', const=True, default=False,
-        help='displays profiling tree after 10e3 steps (default: False)'
-    )
-    parser.add_argument(
-        '--debug', type=parse_bool, nargs='?', const=True, default=False,
-        help='debug mode sends the logs to /tmp/ and overrides number of workers to 3 (default: False)'
-    )
-    parser.add_argument(
-        '--max-queue-length', type=int, default=(size - 1) * 2,
-        help='Maximum rollout queue length. If above the max, workers will wait to append (default: (size - 1) * 2)'
-    )
-    parser.add_argument(
-        '--num-rollouts-in-batch', type=int, default=(size - 1),
-        help='The batch size in rollouts (so total batch is this number * nb_env * seq_len). '
-             + 'Not compatible with --dynamic-batch (default: (size - 1))'
-    )
-    parser.add_argument(
-        '--max-dynamic-batch', type=int, default=0,
-        help='When > 0 uses dynamic batching (disables cudnn and --num-rollouts-in-batch). '
-             + 'Limits the maximum rollouts in the batch to limit GPU memory usage. (default: 0 (False))'
-    )
-    parser.add_argument(
-        '--min-dynamic-batch', type=int, default=0,
-        help='Guarantees a minimum number of rollouts in the batch when using dynamic batching. (default: 0)'
-    )
-    parser.add_argument(
-        '--host-training-info-interval', type=int, default=100,
-        help='The number of training steps before the host writes an info summary. (default: 100)'
-    )
-    parser.add_argument(
-        '--use-local-buffers', type=parse_bool, nargs='?', const=True, default=False,
-        help='If true all workers use their local network buffers (for batch norm: mean & var are not shared) (default: False)'
-    )
-    args = parser.parse_args()
+    base_parser = argparse.ArgumentParser(description='AdeptRL IMPALA Mode')
+
+    def add_args(parser):
+        parser = parser.add_argument_group('IMPALA Mode Args')
+        parser.add_argument('--gpu-id', type=int, nargs='+', default=[0],
+                            help='Which GPU to use for training. The host will always be the first gpu, workers are distributed evenly across the rest (default: [0])')
+        parser.add_argument(
+            '--max-queue-length', type=int, default=(size - 1) * 2,
+            help='Maximum rollout queue length. If above the max, workers will wait to append (default: (size - 1) * 2)'
+        )
+        parser.add_argument(
+            '--num-rollouts-in-batch', type=int, default=(size - 1),
+            help='The batch size in rollouts (so total batch is this number * nb_env * seq_len). '
+                 + 'Not compatible with --dynamic-batch (default: (size - 1))'
+        )
+        parser.add_argument(
+            '--max-dynamic-batch', type=int, default=0,
+            help='When > 0 uses dynamic batching (disables cudnn and --num-rollouts-in-batch). '
+                 + 'Limits the maximum rollouts in the batch to limit GPU memory usage. (default: 0 (False))'
+        )
+        parser.add_argument(
+            '--min-dynamic-batch', type=int, default=0,
+            help='Guarantees a minimum number of rollouts in the batch when using dynamic batching. (default: 0)'
+        )
+        parser.add_argument(
+            '--host-training-info-interval', type=int, default=100,
+            help='The number of training steps before the host writes an info summary. (default: 100)'
+        )
+        parser.add_argument(
+            '--use-local-buffers', type=parse_bool, nargs='?', const=True, default=False,
+            help='If true all workers use their local network buffers (for batch norm: mean & var are not shared) (default: False)'
+        )
+
+    add_base_args(base_parser, add_args)
+    args = base_parser.parse_args()
 
     if args.debug:
         args.nb_env = 3
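
The refactor inverts control: instead of this script calling add_base_args and then piling its own flags onto the same flat parser, it hands add_base_args a callback that attaches the mode-specific flags as an argument group. A standalone sketch of the pattern (the helpers below are stand-ins, not adept's actual implementation):

    # sketch: compose a base parser with a mode-specific argument group
    import argparse

    def add_base_args(parser, add_mode_args):
        base = parser.add_argument_group('Base Args')
        base.add_argument('--seed', type=int, default=0)
        add_mode_args(parser)  # let the mode attach its own group
        return parser

    def add_args(parser):
        group = parser.add_argument_group('IMPALA Mode Args')
        group.add_argument('--gpu-id', type=int, nargs='+', default=[0])

    parser = argparse.ArgumentParser(description='example')
    add_base_args(parser, add_args)
    print(parser.parse_args(['--gpu-id', '0', '1']).gpu_id)  # [0, 1]

Argument groups only change how --help is laid out; parsing and the resulting namespace are unchanged.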