Skip to content

Commit 6803340

Browse files
authored
Network registry (#39)
* beginnings of network_registry * rename to submodule * submodule refinement, network_registry simplification * formatting * modular net init and forward * ModularNetwork.from_args * finish modular_network.from_args, dim to class attribute * update local script * convert basic submodules * simplify spaces, fix submodule output_shapes * bunch of fixes, submodules and identities up to 4d * fix head shapes * fix body not reporting internals * change default net3d to FourConv * towered fixes, some impala fixes * fix impala * copyright notice * starcraft fixes, dtype support, RAM issues? * some ModularNetwork cleanup * progress on validate_shapes * modular network unit tests * fix concat, forward test failing * fix modularnetwork and tests * running
1 parent 541165e commit 6803340

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1886
-750
lines changed

README.md

+7-13
Original file line numberDiff line numberDiff line change
@@ -100,30 +100,24 @@ training frames, since we are interested in sample efficiency.
100100

101101
## API Reference
102102
![architecture](images/architecture.png)
103+
### Containers
104+
Containers hold all of the application state. Each subprocess gets a container
105+
in Towered and IMPALA modes.
103106
### Agents
104107
An Agent acts on and observes the environment.
105108
Currently only ActorCritic is supported. Other agents, such as DQN or ACER, may
106109
be added later.
107-
### Containers
108-
Containers hold all of the application state. Each subprocess gets a container
109-
in Towered and IMPALA modes.
110+
### Networks
111+
Networks are not PyTorch modules, they need to implement our abstract
112+
NetworkModule or ModularNetwork classes. A ModularNetwork consists of
113+
source nets, body, and heads.
110114
### Environments
111115
Environments run in subprocesses and send their observation, rewards,
112116
terminals, and infos to the host process. They work pretty much the same way as
113117
OpenAI's code.
114118
### Experience Caches
115119
An Experience Cache is a Rollout or Experience Replay that is written to after
116120
stepping and read before learning.
117-
### Modules
118-
Modules are generally useful PyTorch modules used in Networks.
119-
### Networks
120-
Networks are not PyTorch modules, they need to implement our abstract
121-
NetworkInterface or ModularNetwork classes. A ModularNetwork consists of a
122-
trunk, body, and head. The Trunk can consist of multiple networks for vision
123-
or discrete data. It flattens these into an embedding. The Body network
124-
operates on the flattened embedding and would typically be an LSTM, Linear
125-
layer, or a combination. The Head depends on the Environment and Agent and is
126-
created accordingly.
127121

128122
## Acknowledgements
129123
We borrow pieces of OpenAI's [gym](https://github.com/openai/gym) and

adept/agents/actor_critic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
)
6666
self._device = device
6767
self.action_space = action_space
68-
self._action_keys = list(sorted(action_space.entries_by_name.keys()))
68+
self._action_keys = list(sorted(action_space.keys()))
6969
self._func_id_to_headnames = None
7070
if self.engine == Engines.SC2:
7171
from adept.environments.deepmind_sc2 import SC2ActionLookup
@@ -112,10 +112,8 @@ def internals(self, new_internals):
112112
self._internals = new_internals
113113

114114
@staticmethod
115-
def output_shape(action_space):
116-
ebn = action_space.entries_by_name
117-
actor_outputs = {name: entry.shape[0] for name, entry in ebn.items()}
118-
head_dict = {'critic': 1, **actor_outputs}
115+
def output_space(action_space):
116+
head_dict = {'critic': (1, ), **action_space}
119117
return head_dict
120118

121119
def act(self, obs):

adept/agents/agent_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def internals(self, new_internals):
6060

6161
@staticmethod
6262
@abc.abstractmethod
63-
def output_shape(action_space):
63+
def output_space(action_space):
6464
raise NotImplementedError
6565

6666
@abc.abstractmethod

adept/agents/agent_registry.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def register_agent(self, agent_id, agent_class):
4747
:return:
4848
"""
4949
assert issubclass(agent_class, AgentModule)
50-
agent_class.check_defaults()
50+
agent_class.check_args_implemented()
5151
self._agent_class_by_id[agent_id] = agent_class
5252

5353
def lookup_agent(self, agent_id):
@@ -59,12 +59,12 @@ def lookup_agent(self, agent_id):
5959
"""
6060
return self._agent_class_by_id[agent_id]
6161

62-
def lookup_output_shape(self, agent_id, action_space):
62+
def lookup_output_space(self, agent_id, action_space):
6363
"""
6464
For a given agent_id, determine the shapes of the outputs.
6565
6666
:param agent_id: str
6767
:param action_space:
6868
:return:
6969
"""
70-
return self._agent_class_by_id[agent_id].output_shape(action_space)
70+
return self._agent_class_by_id[agent_id].output_space(action_space)

adept/agents/impala/actor_critic_vtrace.py

+12-37
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,21 @@
1212
#
1313
# You should have received a copy of the GNU General Public License
1414
# along with this program. If not, see <http://www.gnu.org/licenses/>.
15-
# Use https://github.com/deepmind/scalable_agent/blob/master/vtrace.py for reference
1615
from collections import OrderedDict
1716
import torch
1817
from torch.nn import functional as F
1918

2019
from adept.environments.env_registry import Engines
2120
from adept.expcaches.rollout import RolloutCache
2221
from adept.utils.util import listd_to_dlist, dlist_to_listd
23-
from adept.networks._base import ModularNetwork
2422
from adept.agents.agent_module import AgentModule
2523

2624

2725
class ActorCriticVtrace(AgentModule):
26+
"""
27+
Reference implementation:
28+
Use https://github.com/deepmind/scalable_agent/blob/master/vtrace.py
29+
"""
2830
args = {
2931
'nb_rollout': 20,
3032
'discount': 0.99,
@@ -65,7 +67,7 @@ def __init__(
6567
)
6668
self._device = device
6769
self.action_space = action_space
68-
self._action_keys = list(sorted(action_space.entries_by_name.keys()))
70+
self._action_keys = list(sorted(action_space.keys()))
6971
self._func_id_to_headnames = None
7072
if self.engine == Engines.SC2:
7173
from adept.environments.deepmind_sc2 import SC2ActionLookup
@@ -111,11 +113,8 @@ def internals(self, new_internals):
111113
self._internals = new_internals
112114

113115
@staticmethod
114-
def output_shape(action_space):
115-
ebn = action_space.entries_by_name
116-
actor_outputs = {name: entry.shape[0] for name, entry in ebn.items()}
117-
head_dict = {'critic': 1, **actor_outputs}
118-
return head_dict
116+
def output_space(action_space):
117+
return {'critic': (1, ), **action_space}
119118

120119
def seq_obs_to_pathways(self, obs, device):
121120
"""
@@ -287,40 +286,16 @@ def act_on_host(
287286
log_probs_of_action = []
288287
entropies = []
289288

290-
seq_len, batch_size = terminal_masks.shape
291-
292289
# if network is modular,
293290
# trunk can be sped up by combining batch & seq dim
294291
def get_results_generator():
295-
if isinstance(self.network, ModularNetwork):
296-
pathway_dict = self.gpu_preprocessor(obs, self.device)
297-
# flatten obs
298-
flat_obs = {
299-
k: v.view(-1, *v.shape[2:])
300-
for k, v in pathway_dict.items()
301-
}
302-
embeddings = self.network.trunk.forward(flat_obs)
303-
# add back in seq dim
304-
seq_embeddings = embeddings.view(
305-
seq_len, batch_size, embeddings.shape[-1]
306-
)
307-
308-
def get_results(seq_ind, internals):
309-
embedding = seq_embeddings[seq_ind]
310-
pre_result, internals = self.network.body.forward(
311-
embedding, internals
312-
)
313-
return self.network.head.forward(pre_result, internals)
314-
315-
return get_results
316-
else:
317-
obs_on_device = self.seq_obs_to_pathways(obs, self.device)
292+
obs_on_device = self.seq_obs_to_pathways(obs, self.device)
318293

319-
def get_results(seq_ind, internals):
320-
obs_of_seq_ind = obs_on_device[seq_ind]
321-
return self.network(obs_of_seq_ind, internals)
294+
def get_results(seq_ind, internals):
295+
obs_of_seq_ind = obs_on_device[seq_ind]
296+
return self.network(obs_of_seq_ind, internals)
322297

323-
return get_results
298+
return get_results
324299

325300
result_fn = get_results_generator()
326301
for seq_ind in range(terminal_masks.shape[0]):

adept/app.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def parse_args():
5555
exit(call(['python', '-m', 'adept.scripts.local'] + argv, env=env))
5656
elif args['<command>'] == 'towered':
5757
nb_mpi_proc = input('Enter number of GPU workers [default: 2]\n')
58-
nb_mpi_proc = 2 if not nb_mpi_proc else nb_mpi_proc
58+
nb_mpi_proc = 2 if not nb_mpi_proc else int(nb_mpi_proc)
5959
exit(call([
6060
'mpiexec',
6161
'-n',
@@ -66,7 +66,7 @@ def parse_args():
6666
] + argv, env=env))
6767
elif args['<command>'] == 'impala':
6868
nb_mpi_proc = input('Enter number of GPU workers [default: 2]\n')
69-
nb_mpi_proc = 2 if not nb_mpi_proc else nb_mpi_proc
69+
nb_mpi_proc = 2 if not nb_mpi_proc else int(nb_mpi_proc)
7070
exit(call([
7171
'mpiexec',
7272
'-n',

adept/environments/_spaces.py

+43-26
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,62 @@
1212
#
1313
# You should have received a copy of the GNU General Public License
1414
# along with this program. If not, see <http://www.gnu.org/licenses/>.
15-
from collections.__init__ import namedtuple
16-
17-
import numpy as np
1815
from gym import spaces
1916

20-
Space = namedtuple('Space', ['shape', 'low', 'high', 'dtype'])
21-
2217

23-
class Spaces:
18+
class Space(dict):
2419
def __init__(self, entries_by_name):
25-
self.entries_by_name = entries_by_name
26-
self.names_by_rank = {1: [], 2: [], 3: [], 4: []}
27-
for name, entry in entries_by_name.items():
28-
self.names_by_rank[len(entry.shape)].append(name)
20+
super(Space, self).__init__(entries_by_name)
2921

3022
@classmethod
3123
def from_gym(cls, gym_space):
32-
entries_by_name = Spaces._detect_gym_spaces(gym_space)
24+
entries_by_name = Space._detect_gym_spaces(gym_space)
3325
return cls(entries_by_name)
3426

3527
@staticmethod
36-
def _detect_gym_spaces(space):
37-
if isinstance(space, spaces.Discrete):
38-
return {'Discrete': Space([space.n], 0, 1, np.float32)}
39-
elif isinstance(space, spaces.MultiDiscrete):
28+
def _detect_gym_spaces(gym_space):
29+
if isinstance(gym_space, spaces.Discrete):
30+
return {'Discrete': (gym_space.n,)}
31+
elif isinstance(gym_space, spaces.MultiDiscrete):
4032
raise NotImplementedError
41-
elif isinstance(space, spaces.MultiBinary):
42-
return {'MultiBinary': Space([space.n], 0, 1, space.dtype)}
43-
elif isinstance(space, spaces.Box):
33+
elif isinstance(gym_space, spaces.MultiBinary):
34+
return {'MultiBinary': (gym_space.n,)}
35+
elif isinstance(gym_space, spaces.Box):
4436
return {
45-
'Box': Space(space.shape, 0., 255., space.dtype)
46-
} # TODO, is it okay to hardcode 0, 255
47-
elif isinstance(space, spaces.Dict):
37+
'Box': gym_space.shape
38+
}
39+
elif isinstance(gym_space, spaces.Dict):
4840
return {
49-
name: list(Spaces._detect_gym_spaces(s).values())[0]
50-
for name, s in space.spaces.items()
41+
name: list(Space._detect_gym_spaces(s).values())[0]
42+
for name, s in gym_space.spaces.items()
5143
}
52-
elif isinstance(space, spaces.Tuple):
44+
elif isinstance(gym_space, spaces.Tuple):
5345
return {
54-
idx: list(Spaces._detect_gym_spaces(s).values())[0]
55-
for idx, s in enumerate(space.spaces)
46+
idx: list(Space._detect_gym_spaces(s).values())[0]
47+
for idx, s in enumerate(gym_space.spaces)
5648
}
49+
50+
@staticmethod
51+
def dtypes_from_gym(gym_space):
52+
if isinstance(gym_space, spaces.Discrete):
53+
return {'Discrete': gym_space.dtype}
54+
elif isinstance(gym_space, spaces.MultiDiscrete):
55+
raise NotImplementedError
56+
elif isinstance(gym_space, spaces.MultiBinary):
57+
return {'MultiBinary': gym_space.dtype}
58+
elif isinstance(gym_space, spaces.Box):
59+
return {
60+
'Box': gym_space.dtype
61+
}
62+
elif isinstance(gym_space, spaces.Dict):
63+
return {
64+
name: Space.dtypes_from_gym(s)
65+
for name, s in gym_space.spaces.items()
66+
}
67+
elif isinstance(gym_space, spaces.Tuple):
68+
return {
69+
idx: Space.dtypes_from_gym(s)
70+
for idx, s in enumerate(gym_space.spaces)
71+
}
72+
else:
73+
raise NotImplementedError

0 commit comments

Comments
 (0)