
Commit 52a48a5

benbellheron authored and jtatusko committed

Reformat script arguments (#17)

* Reformat script arguments
* Add support for loading network/optimizer
* Update readme
1 parent a641596 commit 52a48a5

13 files changed: +397 −270 lines

README.md (+3 −3)

@@ -66,15 +66,15 @@ The log directory contains the tensorboard file, saved models, and other metadata
 ```
 # Local Mode (A2C)
 # We recommend 4GB+ GPU memory, 8GB+ RAM, 4+ Cores
-python -m adept.scripts.local --env-id BeamRiderNoFrameskip-v4
+python -m adept.scripts.local ActorCritic --env-id BeamRiderNoFrameskip-v4

 # Towered Mode (A3C Variant, requires mpi4py)
 # We recommend 2+ GPUs, 8GB+ GPU memory, 32GB+ RAM, 4+ Cores
-python -m adept.scripts.towered --env-id BeamRiderNoFrameskip-v4
+python -m adept.scripts.towered ActorCritic --env-id BeamRiderNoFrameskip-v4

 # IMPALA (requires mpi4py and is resource intensive)
 # We recommend 2+ GPUs, 8GB+ GPU memory, 32GB+ RAM, 4+ Cores
-mpiexec -n 3 python -m adept.scripts.impala --env-id BeamRiderNoFrameskip-v4
+mpiexec -n 3 python -m adept.scripts.impala ActorCriticVtrace --env-id BeamRiderNoFrameskip-v4

 # StarCraft 2 (IMPALA not supported yet)
 # Warning: much more resource intensive than Atari

adept/agents/__init__.py (+1)

@@ -14,6 +14,7 @@
 You should have received a copy of the GNU General Public License
 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
+
 from .actor_critic import ActorCritic
 from .impala import ActorCriticVtrace

adept/agents/actor_critic.py (+43 −1)

@@ -14,14 +14,15 @@
 You should have received a copy of the GNU General Public License
 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
+from argparse import ArgumentParser
 from collections import OrderedDict

 import torch
 from adept.environments import Engines
 from torch.nn import functional as F

 from adept.expcaches.rollout import RolloutCache
-from adept.utils.util import listd_to_dlist
+from adept.utils.util import listd_to_dlist, parse_bool
 from ._base import Agent


@@ -66,6 +67,47 @@ def from_args(cls, network, device, reward_normalizer, gpu_preprocessor, engine,
             args.nb_env, args.exp_length, args.discount, args.generalized_advantage_estimation, args.tau, args.normalize_advantage
         )

+    @classmethod
+    def add_args(cls, parser: ArgumentParser):
+        parser.add_argument(
+            '-ae',
+            '--exp-length',
+            type=int,
+            default=20,
+            help='Experience length (default: 20)'
+        )
+        parser.add_argument(
+            '-ag',
+            '--generalized-advantage-estimation',
+            type=parse_bool,
+            nargs='?',
+            const=True,
+            default=True,
+            help='Use generalized advantage estimation for the policy loss. (default: True)'
+        )
+        parser.add_argument(
+            '-at',
+            '--tau',
+            type=float,
+            default=1.00,
+            help='parameter for GAE (default: 1.00)'
+        )
+        parser.add_argument(
+            '--entropy-weight',
+            type=float,
+            default=0.01,
+            help='Entropy penalty (default: 0.01)'
+        )
+        parser.add_argument(
+            '--normalize-advantage',
+            type=parse_bool,
+            nargs='?',
+            const=True,
+            default=False,
+            help='Normalize the advantage when calculating policy loss. (default: False)'
+        )
+
+
     @property
     def exp_cache(self):
         return self._exp_cache
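The new `add_args` classmethod lets each agent contribute its own flags to the script parser. A minimal, self-contained sketch of that pattern (the `DemoAgent` class and the `parse_bool` helper below are illustrative stand-ins, not adept's exact implementations):

```python
from argparse import ArgumentParser


def parse_bool(value):
    # accept 'True'/'False'-style strings so boolean flags can also take a value
    return str(value).lower() in ('true', 't', '1', 'yes')


class DemoAgent:
    @classmethod
    def add_args(cls, parser: ArgumentParser):
        # each agent registers only the hyperparameters it actually uses
        parser.add_argument('--exp-length', type=int, default=20,
                            help='Experience length (default: 20)')
        parser.add_argument('--normalize-advantage', type=parse_bool,
                            nargs='?', const=True, default=False,
                            help='Normalize advantage (default: False)')


if __name__ == '__main__':
    parser = ArgumentParser()
    DemoAgent.add_args(parser)
    args = parser.parse_args(['--exp-length', '32', '--normalize-advantage'])
    print(args.exp_length, args.normalize_advantage)  # -> 32 True
```

With `nargs='?'` and `const=True`, passing the bare flag enables the option, while an explicit `--normalize-advantage False` still parses cleanly through `parse_bool`.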

adept/agents/impala/actor_critic_vtrace.py (+13)

@@ -15,6 +15,7 @@
 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
 # Use https://github.com/deepmind/scalable_agent/blob/master/vtrace.py for reference
+from argparse import ArgumentParser
 from collections import OrderedDict
 import torch
 from torch.nn import functional as F

@@ -54,6 +55,16 @@ def from_args(cls, network, device, reward_normalizer, gpu_preprocessor, engine,
             args.nb_env, args.exp_length, args.discount
         )

+    @classmethod
+    def add_args(cls, parser):
+        parser.add_argument(
+            '-ae',
+            '--exp-length',
+            type=int,
+            default=20,
+            help='Experience length (default: 20)'
+        )
+
     @property
     def exp_cache(self):
         return self._exp_cache

@@ -440,3 +451,5 @@ def _vtrace_returns(log_diff_behavior_vs_current, discount_terminal_mask, reward

     weighted_advantage = clamped_importance_pg * advantage
     return v_s, weighted_advantage, importance
+
+

adept/containers/impala.py (+1)

@@ -311,6 +311,7 @@ def summary_writer(self):
         return self._summary_writer

     def run(self, initial_count=0):
+        self.local_step_count = initial_count
         next_obs = self.environment.reset()
         self._starting_internals = self.agent.internals
         while not self.should_stop():

adept/containers/towered.py (+2 −1)

@@ -301,6 +301,7 @@ def nb_env(self):
         return self._nb_env

     def run(self, initial_count=0):
+        self.local_step_count = initial_count
         next_obs = self.environment.reset()
         self.start_time = time.time()
         while not self.should_stop():

@@ -344,7 +345,7 @@ def submit(self):
         if host_info is not None:
             self.global_step = host_info
         else:
-            self.global_step = 0
+            self.global_step = self.local_step_count
         # host decides when it wants pytorch buffers
         if self.mpi_buffer_request.test()[0]:
             buffer_list = [x.cpu().numpy() for x in self.network._all_buffers()]
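Both containers now seed `local_step_count` from `initial_count`, so a resumed process keeps counting from the reloaded checkpoint instead of restarting at zero. A toy sketch of why that matters (illustrative only, not adept's container logic):

```python
# Illustrative: step-based stopping (and, by extension, logging/saving schedules)
# should resume from the reloaded step count rather than from zero.
class ToyContainer:
    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.local_step_count = 0

    def should_stop(self):
        return self.local_step_count >= self.max_steps

    def run(self, initial_count=0):
        self.local_step_count = initial_count
        while not self.should_stop():
            self.local_step_count += 1  # one environment step per loop iteration
        return self.local_step_count


print(ToyContainer(1000).run(initial_count=900))  # only the remaining 100 steps execute
```

The `towered.py` change follows the same logic: when the host has not yet reported a global step, a worker now reports its own resumed count rather than 0.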

adept/environments/atari.py (+3 −3)

@@ -27,14 +27,14 @@
 import numpy as np


-def make_atari_env(env_id, skip_rate, max_ep_length, do_zscore_norm, do_frame_stack, seed):
+def make_atari_env(env_id, skip_rate, max_ep_length, do_frame_stack, seed):
     def _f():
-        env = atari_env(env_id, skip_rate, max_ep_length, do_zscore_norm, do_frame_stack, seed)
+        env = atari_env(env_id, skip_rate, max_ep_length, do_frame_stack, seed)
         return env
     return _f


-def atari_env(env_id, skip_rate, max_ep_length, do_zscore_norm, do_frame_stack, seed):
+def atari_env(env_id, skip_rate, max_ep_length, do_frame_stack, seed):
     env = gym.make(env_id)
     if hasattr(env.unwrapped, 'ale'):
         if 'FIRE' in env.unwrapped.get_action_meanings():
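`make_atari_env` keeps its factory-closure shape (it returns a zero-argument thunk) even as the `do_zscore_norm` flag is dropped. A generic sketch of that pattern, not adept's exact wrappers:

```python
import gym


def make_env_fn(env_id, seed):
    # returning a thunk lets a vectorized-env container construct each copy
    # lazily, typically in its own worker process, each with a distinct seed
    def _f():
        env = gym.make(env_id)
        env.seed(seed)
        return env
    return _f


# one factory per parallel environment
env_fns = [make_env_fn('PongNoFrameskip-v4', seed) for seed in range(4)]
envs = [fn() for fn in env_fns]
```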

adept/scripts/benchmark_atari.py (+1 −1)

@@ -54,7 +54,7 @@
         help='Number of eval steps allowed to run per second decreasing this amount can improve training speed. 0 to disable (default: 0)'
     )

-    args = parser.parse_args()
+    args = parser.add_args()

     args.mode_name = 'Local'

adept/scripts/impala.py (+79 −70)

@@ -16,6 +16,7 @@
 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
 import os
+from copy import deepcopy
 from mpi4py import MPI as mpi
 import torch
 from absl import flags
@@ -41,7 +42,7 @@ def main(args):
     if rank == 0:
         timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
         log_id = make_log_id_from_timestamp(args.tag, args.mode_name, args.agent,
-                                            args.vision_network + args.network_body,
+                                            args.network_vision + args.network_body,
                                             timestamp)
         log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)
         os.makedirs(log_id_dir)
@@ -53,34 +54,54 @@ def main(args):

     if rank != 0:
         log_id = make_log_id_from_timestamp(args.tag, args.mode_name, args.agent,
-                                            args.vision_network + args.network_body,
+                                            args.network_vision + args.network_body,
                                             timestamp)
         log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

     comm.Barrier()

     # construct env
-    seed = args.seed if rank == 0 else args.seed + (args.nb_env * (rank - 1))  # unique seed per process
-    env = make_env(args, seed)
+    # unique seed per process
+    seed = args.seed if rank == 0 else args.seed + args.nb_env * (rank - 1)
+    # don't make a ton of envs if host
+    if rank == 0:
+        env_args = deepcopy(args)
+        env_args.nb_env = 1
+        env = make_env(env_args, seed)
+    else:
+        env = make_env(args, seed)

     # construct network
     torch.manual_seed(args.seed)
     network_head_shapes = get_head_shapes(env.action_space, args.agent)
     network = make_network(env.observation_space, network_head_shapes, args)

-    # sync network params
-    if rank == 0:
-        for v in network.parameters():
-            comm.Bcast(v.detach().cpu().numpy(), root=0)
-        print('Root variables synced')
+    # possibly load network
+    initial_step_count = 0
+    if args.load_network:
+        network.load_state_dict(
+            torch.load(
+                args.load_network, map_location=lambda storage, loc: storage
+            )
+        )
+        # get step count from network file
+        epoch_dir = os.path.split(args.load_network)[0]
+        initial_step_count = int(os.path.split(epoch_dir)[-1])
+        print('Reloaded network from {}'.format(args.load_network))
+    # only sync network params if not loading
     else:
-        # can just use the numpy buffers
-        variables = [v.detach().cpu().numpy() for v in network.parameters()]
-        for v in variables:
-            comm.Bcast(v, root=0)
-        for shared_v, model_v in zip(variables, network.parameters()):
-            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
-        print('{} variables synced'.format(rank))
+        if rank == 0:
+            for v in network.parameters():
+                comm.Bcast(v.detach().cpu().numpy(), root=0)
+            print('Root variables synced')
+        else:
+            # can just use the numpy buffers
+            variables = [v.detach().cpu().numpy() for v in network.parameters()]
+            for v in variables:
+                comm.Bcast(v, root=0)
+            for shared_v, model_v in zip(variables, network.parameters()):
+                model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
+            print('{} variables synced'.format(rank))

     # construct agent
     # host is always the first gpu, workers are distributed evenly across the rest
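The broadcast branch above is the standard mpi4py pattern for keeping workers in sync with the host's initial weights: rank 0 broadcasts each parameter as a NumPy buffer, and every other rank receives into matching buffers and copies them back into its model in the same order. A stripped-down sketch of that mechanism (stand-in model, not adept's network):

```python
from mpi4py import MPI
import torch

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

model = torch.nn.Linear(4, 2)  # stand-in for the real network

if rank == 0:
    # CPU tensors share memory with their numpy views, so Bcast sends the live weights
    for p in model.parameters():
        comm.Bcast(p.detach().cpu().numpy(), root=0)
else:
    buffers = [p.detach().cpu().numpy() for p in model.parameters()]
    for buf in buffers:
        comm.Bcast(buf, root=0)  # receive into the numpy buffer in place
    for buf, p in zip(buffers, model.parameters()):
        p.data.copy_(torch.from_numpy(buf))
```

When `args.load_network` is set, the broadcast is skipped entirely because every rank loads the same checkpoint from disk.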
@@ -120,7 +141,7 @@ def main(args):
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
-            container.run()
+            container.run(initial_step_count)
        env.close()
    # host
    else:
@@ -136,6 +157,12 @@ def main(args):
        # Construct the optimizer
        def make_optimizer(params):
            opt = torch.optim.RMSprop(params, lr=args.learning_rate, eps=1e-5, alpha=0.99)
+            if args.load_optimizer:
+                opt.load_state_dict(
+                    torch.load(
+                        args.load_optimizer, map_location=lambda storage, loc: storage
+                    )
+                )
            return opt

        container = ImpalaHost(agent, comm, make_optimizer, summary_writer, args.summary_frequency, saver,
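Network and optimizer restore both use `torch.load` with a `map_location` that keeps tensors on CPU until the state dict is applied, and the starting step is recovered from the checkpoint's parent directory name. A compact sketch of that resume path (assumes checkpoints saved under a directory named after the global step, e.g. `.../<step>/model.pth`; the helper name is illustrative):

```python
import os
import torch


def load_checkpoint(network, optimizer, network_path, optimizer_path=None):
    # load onto CPU regardless of where the checkpoint was saved
    cpu = lambda storage, loc: storage
    network.load_state_dict(torch.load(network_path, map_location=cpu))
    if optimizer_path:
        # the optimizer must already be built over the same parameters
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=cpu))
    # recover the global step from the enclosing epoch directory name
    epoch_dir = os.path.split(network_path)[0]
    return int(os.path.split(epoch_dir)[-1])
```

Reconstructing the optimizer first and then loading its state mirrors the `make_optimizer` hook above: the state dict can only be mapped onto parameters the optimizer already tracks.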
@@ -169,59 +196,41 @@ def make_optimizer(params):
     import argparse
     from adept.utils.script_helpers import add_base_args, parse_bool

-    parser = argparse.ArgumentParser(description='AdeptRL IMPALA Mode')
-    parser = add_base_args(parser)
-    parser.add_argument('--gpu-id', type=int, nargs='+', default=[0],
-                        help='Which GPU to use for training. The host will always be the first gpu, workers are distributed evenly across the rest (default: [0])')
-    parser.add_argument(
-        '-vn', '--vision-network', default='Nature',
-        help='name of preset network (default: Nature)'
-    )
-    parser.add_argument(
-        '-dn', '--discrete-network', default='Identity',
-    )
-    parser.add_argument(
-        '-nb', '--network-body', default='LSTM',
-    )
-    parser.add_argument(
-        '--agent', default='ActorCriticVtrace',
-        help='name of preset agent (default: ActorCriticVtrace)'
-    )
-    parser.add_argument(
-        '--profile', type=parse_bool, nargs='?', const=True, default=False,
-        help='displays profiling tree after 10e3 steps (default: False)'
-    )
-    parser.add_argument(
-        '--debug', type=parse_bool, nargs='?', const=True, default=False,
-        help='debug mode sends the logs to /tmp/ and overrides number of workers to 3 (default: False)'
-    )
-    parser.add_argument(
-        '--max-queue-length', type=int, default=(size - 1) * 2,
-        help='Maximum rollout queue length. If above the max, workers will wait to append (default: (size - 1) * 2)'
-    )
-    parser.add_argument(
-        '--num-rollouts-in-batch', type=int, default=(size - 1),
-        help='The batch size in rollouts (so total batch is this number * nb_env * seq_len). '
-             + 'Not compatible with --dynamic-batch (default: (size - 1))'
-    )
-    parser.add_argument(
-        '--max-dynamic-batch', type=int, default=0,
-        help='When > 0 uses dynamic batching (disables cudnn and --num-rollouts-in-batch). '
-             + 'Limits the maximum rollouts in the batch to limit GPU memory usage. (default: 0 (False))'
-    )
-    parser.add_argument(
-        '--min-dynamic-batch', type=int, default=0,
-        help='Guarantees a minimum number of rollouts in the batch when using dynamic batching. (default: 0)'
-    )
-    parser.add_argument(
-        '--host-training-info-interval', type=int, default=100,
-        help='The number of training steps before the host writes an info summary. (default: 100)'
-    )
-    parser.add_argument(
-        '--use-local-buffers', type=parse_bool, nargs='?', const=True, default=False,
-        help='If true all workers use their local network buffers (for batch norm: mean & var are not shared) (default: False)'
-    )
-    args = parser.parse_args()
+    base_parser = argparse.ArgumentParser(description='AdeptRL IMPALA Mode')
+
+    def add_args(parser):
+        parser = parser.add_argument_group('IMPALA Mode Args')
+        parser.add_argument('--gpu-id', type=int, nargs='+', default=[0],
+                            help='Which GPU to use for training. The host will always be the first gpu, workers are distributed evenly across the rest (default: [0])')
+        parser.add_argument(
+            '--max-queue-length', type=int, default=(size - 1) * 2,
+            help='Maximum rollout queue length. If above the max, workers will wait to append (default: (size - 1) * 2)'
+        )
+        parser.add_argument(
+            '--num-rollouts-in-batch', type=int, default=(size - 1),
+            help='The batch size in rollouts (so total batch is this number * nb_env * seq_len). '
+                 + 'Not compatible with --dynamic-batch (default: (size - 1))'
+        )
+        parser.add_argument(
+            '--max-dynamic-batch', type=int, default=0,
+            help='When > 0 uses dynamic batching (disables cudnn and --num-rollouts-in-batch). '
+                 + 'Limits the maximum rollouts in the batch to limit GPU memory usage. (default: 0 (False))'
+        )
+        parser.add_argument(
+            '--min-dynamic-batch', type=int, default=0,
+            help='Guarantees a minimum number of rollouts in the batch when using dynamic batching. (default: 0)'
+        )
+        parser.add_argument(
+            '--host-training-info-interval', type=int, default=100,
+            help='The number of training steps before the host writes an info summary. (default: 100)'
+        )
+        parser.add_argument(
+            '--use-local-buffers', type=parse_bool, nargs='?', const=True, default=False,
+            help='If true all workers use their local network buffers (for batch norm: mean & var are not shared) (default: False)'
+        )
+
+    add_base_args(base_parser, add_args)
+    args = base_parser.parse_args()

     if args.debug:
         args.nb_env = 3
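The script-level arguments now live in a callback that `add_base_args` invokes, so shared flags and mode-specific flags end up on one parser. The real helper is in `adept.utils.script_helpers`; the sketch below only illustrates the calling convention implied by `add_base_args(base_parser, add_args)` (the base flags shown are assumptions, not the actual list):

```python
import argparse


def add_base_args(parser, script_args_fn):
    # shared flags every script needs (illustrative subset, not adept's real list)
    parser.add_argument('--env-id', default='BeamRiderNoFrameskip-v4')
    parser.add_argument('--nb-env', type=int, default=32)
    # let the calling script contribute its own argument group
    script_args_fn(parser)
    return parser


def add_args(parser):
    group = parser.add_argument_group('IMPALA Mode Args')
    group.add_argument('--gpu-id', type=int, nargs='+', default=[0])


base_parser = argparse.ArgumentParser(description='AdeptRL IMPALA Mode')
add_base_args(base_parser, add_args)
args = base_parser.parse_args([])
print(args.env_id, args.gpu_id)
```

Agent-specific flags (the `add_args` classmethods added to `ActorCritic` and `ActorCriticVtrace` above) can be attached the same way once the agent class is known.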
