
Commit ff9bd1e

Author: Ervin T

[poca] Make Observation Encoder a module (#5093)

* Make Observation Encoder a module
* Fix copy normalize

1 parent 5d3e500, commit ff9bd1e

2 files changed: +122 -117 lines changed

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 121 additions & 55 deletions
@@ -7,14 +7,19 @@
 from mlagents.trainers.torch.action_model import ActionModel
 from mlagents.trainers.torch.agent_action import AgentAction
 from mlagents.trainers.torch.action_log_probs import ActionLogProbs
-from mlagents.trainers.settings import NetworkSettings
+from mlagents.trainers.settings import NetworkSettings, EncoderType
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
 from mlagents.trainers.torch.layers import LSTM, LinearEncoder
 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
-from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
+from mlagents.trainers.torch.attention import (
+    EntityEmbedding,
+    ResidualSelfAttention,
+    get_zero_entities_mask,
+)
+from mlagents.trainers.exception import UnityTrainerException
 
 
 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
@@ -25,6 +30,103 @@
 EPSILON = 1e-7
 
 
+class ObservationEncoder(nn.Module):
+    def __init__(
+        self,
+        observation_specs: List[ObservationSpec],
+        h_size: int,
+        vis_encode_type: EncoderType,
+        normalize: bool = False,
+    ):
+        """
+        Returns an ObservationEncoder that can process and encode a set of observations.
+        Will use an RSA if needed for variable length observations.
+        """
+        super().__init__()
+        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
+            observation_specs, h_size, vis_encode_type, normalize=normalize
+        )
+        self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
+            self.processors, self.embedding_sizes, h_size
+        )
+        if self.rsa is not None:
+            total_enc_size = sum(self.embedding_sizes) + h_size
+        else:
+            total_enc_size = sum(self.embedding_sizes)
+        self.normalize = normalize
+        self._total_enc_size = total_enc_size
+
+    @property
+    def total_enc_size(self) -> int:
+        """
+        Returns the total encoding size for this ObservationEncoder.
+        """
+        return self._total_enc_size
+
+    def update_normalization(self, buffer: AgentBuffer) -> None:
+        obs = ObsUtil.from_buffer(buffer, len(self.processors))
+        for vec_input, enc in zip(obs, self.processors):
+            if isinstance(enc, VectorInput):
+                enc.update_normalization(torch.as_tensor(vec_input))
+
+    def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
+        if self.normalize:
+            for n1, n2 in zip(self.processors, other_encoder.processors):
+                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
+                    n1.copy_normalization(n2)
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Encode observations using a list of processors and an RSA.
+        :param inputs: List of Tensors corresponding to a set of obs.
+        :param processors: a ModuleList of the input processors to be applied to these obs.
+        :param rsa: Optionally, an RSA to use for variable length obs.
+        :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
+        """
+        encodes = []
+        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
+
+        for idx, processor in enumerate(self.processors):
+            if not isinstance(processor, EntityEmbedding):
+                # The input can be encoded without having to process other inputs
+                obs_input = inputs[idx]
+                processed_obs = processor(obs_input)
+                encodes.append(processed_obs)
+            else:
+                var_len_processor_inputs.append((processor, inputs[idx]))
+        if len(encodes) != 0:
+            encoded_self = torch.cat(encodes, dim=1)
+            input_exist = True
+        else:
+            input_exist = False
+        if len(var_len_processor_inputs) > 0 and self.rsa is not None:
+            # Some inputs need to be processed with a variable length encoder
+            masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
+            embeddings: List[torch.Tensor] = []
+            processed_self = (
+                self.x_self_encoder(encoded_self)
+                if input_exist and self.x_self_encoder is not None
+                else None
+            )
+            for processor, var_len_input in var_len_processor_inputs:
+                embeddings.append(processor(processed_self, var_len_input))
+            qkv = torch.cat(embeddings, dim=1)
+            attention_embedding = self.rsa(qkv, masks)
+            if not input_exist:
+                encoded_self = torch.cat([attention_embedding], dim=1)
+                input_exist = True
+            else:
+                encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
+
+        if not input_exist:
+            raise UnityTrainerException(
+                "The trainer was unable to process any of the provided inputs. "
+                "Make sure the trained agents have at least one sensor attached to them."
+            )
+
+        return encoded_self
+
+
 class NetworkBody(nn.Module):
     def __init__(
         self,
@@ -41,22 +143,13 @@ def __init__(
             if network_settings.memory is not None
             else 0
         )
-
-        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
+        self.observation_encoder = ObservationEncoder(
             observation_specs,
             self.h_size,
             network_settings.vis_encode_type,
-            normalize=self.normalize,
-        )
-
-        self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
-            self.processors, self.embedding_sizes, self.h_size
+            self.normalize,
         )
-        if self.rsa is not None:
-            total_enc_size = sum(self.embedding_sizes) + self.h_size
-        else:
-            total_enc_size = sum(self.embedding_sizes)
-
+        total_enc_size = self.observation_encoder.total_enc_size
         total_enc_size += encoded_act_size
         self.linear_encoder = LinearEncoder(
             total_enc_size, network_settings.num_layers, self.h_size
@@ -68,16 +161,10 @@ def __init__(
             self.lstm = None  # type: ignore
 
     def update_normalization(self, buffer: AgentBuffer) -> None:
-        obs = ObsUtil.from_buffer(buffer, len(self.processors))
-        for vec_input, enc in zip(obs, self.processors):
-            if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input))
+        self.observation_encoder.update_normalization(buffer)
 
     def copy_normalization(self, other_network: "NetworkBody") -> None:
-        if self.normalize:
-            for n1, n2 in zip(self.processors, other_network.processors):
-                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
-                    n1.copy_normalization(n2)
+        self.observation_encoder.copy_normalization(other_network.observation_encoder)
 
     @property
     def memory_size(self) -> int:
@@ -90,9 +177,7 @@ def forward(
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        encoded_self = ModelUtils.encode_observations(
-            inputs, self.processors, self.rsa, self.x_self_encoder
-        )
+        encoded_self = self.observation_encoder(inputs)
         if actions is not None:
             encoded_self = torch.cat([encoded_self, actions], dim=1)
         encoding = self.linear_encoder(encoded_self)
@@ -127,27 +212,18 @@ def __init__(
             if network_settings.memory is not None
             else 0
         )
-        self.processors, _input_size = ModelUtils.create_input_processors(
+        self.action_spec = action_spec
+        self.observation_encoder = ObservationEncoder(
             observation_specs,
             self.h_size,
             network_settings.vis_encode_type,
-            normalize=self.normalize,
-        )
-        self.action_spec = action_spec
-        # This RSA and input are for variable length obs, not for multi-agent.
-        (
-            self.input_rsa,
-            self.input_x_self_encoder,
-        ) = ModelUtils.create_residual_self_attention(
-            self.processors, _input_size, self.h_size
+            self.normalize,
         )
-        if self.input_rsa is not None:
-            _input_size.append(self.h_size)
 
         # Modules for multi-agent self-attention
-        obs_only_ent_size = sum(_input_size)
+        obs_only_ent_size = self.observation_encoder.total_enc_size
         q_ent_size = (
-            sum(_input_size)
+            obs_only_ent_size
             + sum(self.action_spec.discrete_branches)
             + self.action_spec.continuous_size
         )
@@ -173,16 +249,10 @@ def memory_size(self) -> int:
         return self.lstm.memory_size if self.use_lstm else 0
 
     def update_normalization(self, buffer: AgentBuffer) -> None:
-        obs = ObsUtil.from_buffer(buffer, len(self.processors))
-        for vec_input, enc in zip(obs, self.processors):
-            if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input))
+        self.observation_encoder.update_normalization(buffer)
 
     def copy_normalization(self, other_network: "MultiAgentNetworkBody") -> None:
-        if self.normalize:
-            for n1, n2 in zip(self.processors, other_network.processors):
-                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
-                    n1.copy_normalization(n2)
+        self.observation_encoder.copy_normalization(other_network.observation_encoder)
 
     def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
         """
@@ -243,9 +313,7 @@ def forward(
         obs_attn_mask = self._get_masks_from_nans(obs)
         obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
         for inputs, action in zip(obs, actions):
-            encoded = ModelUtils.encode_observations(
-                inputs, self.processors, self.input_rsa, self.input_x_self_encoder
-            )
+            encoded = self.observation_encoder(inputs)
             cat_encodes = [
                 encoded,
                 action.to_flat(self.action_spec.discrete_branches),
@@ -260,9 +328,7 @@ def forward(
         obs_only_attn_mask = self._get_masks_from_nans(obs_only)
         obs_only = self._copy_and_remove_nans_from_obs(obs_only, obs_only_attn_mask)
         for inputs in obs_only:
-            encoded = ModelUtils.encode_observations(
-                inputs, self.processors, self.input_rsa, self.input_x_self_encoder
-            )
+            encoded = self.observation_encoder(inputs)
             concat_encoded_obs.append(encoded)
         g_inp = torch.stack(concat_encoded_obs, dim=1)
         self_attn_masks.append(obs_only_attn_mask)
@@ -530,10 +596,10 @@ def forward(
         end = 0
         vis_index = 0
         var_len_index = 0
-        for i, enc in enumerate(self.network_body.processors):
+        for i, enc in enumerate(self.network_body.observation_encoder.processors):
            if isinstance(enc, VectorInput):
                # This is a vec_obs
-                vec_size = self.network_body.embedding_sizes[i]
+                vec_size = self.network_body.observation_encoder.embedding_sizes[i]
                end = start + vec_size
                inputs.append(concatenated_vec_obs[:, start:end])
                start = end
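Taken together, the networks.py changes are a classic extract-module refactor: the processor list, the RSA, and the normalization plumbing that NetworkBody and MultiAgentNetworkBody each carried separately now sit behind a single nn.Module with a small interface (forward, total_enc_size, update_normalization, copy_normalization). The sketch below is a minimal, self-contained re-creation of that pattern in plain PyTorch; TinyObservationEncoder, TinyNetworkBody, and all sizes are illustrative stand-ins, not part of the ML-Agents API.

import torch
from torch import nn
from typing import List


class TinyObservationEncoder(nn.Module):
    """Illustrative stand-in for the ObservationEncoder added in this commit."""

    def __init__(self, obs_sizes: List[int], h_size: int):
        super().__init__()
        # One "processor" per observation, mirroring self.processors above.
        self.processors = nn.ModuleList(nn.Linear(n, h_size) for n in obs_sizes)
        self._total_enc_size = h_size * len(obs_sizes)

    @property
    def total_enc_size(self) -> int:
        # Exposed so the owning network can size its downstream layers.
        return self._total_enc_size

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # Encode each observation, then concatenate along the feature axis.
        return torch.cat([p(x) for p, x in zip(self.processors, inputs)], dim=1)


class TinyNetworkBody(nn.Module):
    """Illustrative stand-in for NetworkBody after the refactor."""

    def __init__(self, obs_sizes: List[int], h_size: int):
        super().__init__()
        # The body now owns one encoder module instead of raw processors/RSA.
        self.observation_encoder = TinyObservationEncoder(obs_sizes, h_size)
        self.linear_encoder = nn.Linear(
            self.observation_encoder.total_enc_size, h_size
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        return self.linear_encoder(self.observation_encoder(inputs))


body = TinyNetworkBody(obs_sizes=[8, 4], h_size=16)
out = body([torch.rand(2, 8), torch.rand(2, 4)])  # out.shape == (2, 16)

One payoff of the pattern, visible in the hunks above: normalization updates and copies become one-line delegations, and the same encoder serves both network bodies.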

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 1 addition & 62 deletions
@@ -11,11 +11,7 @@
     VectorInput,
 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
-from mlagents.trainers.torch.attention import (
-    EntityEmbedding,
-    ResidualSelfAttention,
-    get_zero_entities_mask,
-)
+from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents_envs.base_env import ObservationSpec, DimensionProperty
 
@@ -372,63 +368,6 @@ def create_residual_self_attention(
         rsa = ResidualSelfAttention(hidden_size, entity_num_max)
         return rsa, x_self_encoder
 
-    @staticmethod
-    def encode_observations(
-        inputs: List[torch.Tensor],
-        processors: nn.ModuleList,
-        rsa: Optional[ResidualSelfAttention],
-        x_self_encoder: Optional[LinearEncoder],
-    ) -> torch.Tensor:
-        """
-        Helper method to encode observations using a list of processors and an RSA.
-        :param inputs: List of Tensors corresponding to a set of obs.
-        :param processors: a ModuleList of the input processors to be applied to these obs.
-        :param rsa: Optionally, an RSA to use for variable length obs.
-        :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
-        """
-        encodes = []
-        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
-
-        for idx, processor in enumerate(processors):
-            if not isinstance(processor, EntityEmbedding):
-                # The input can be encoded without having to process other inputs
-                obs_input = inputs[idx]
-                processed_obs = processor(obs_input)
-                encodes.append(processed_obs)
-            else:
-                var_len_processor_inputs.append((processor, inputs[idx]))
-        if len(encodes) != 0:
-            encoded_self = torch.cat(encodes, dim=1)
-            input_exist = True
-        else:
-            input_exist = False
-        if len(var_len_processor_inputs) > 0 and rsa is not None:
-            # Some inputs need to be processed with a variable length encoder
-            masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
-            embeddings: List[torch.Tensor] = []
-            processed_self = (
-                x_self_encoder(encoded_self)
-                if input_exist and x_self_encoder is not None
-                else None
-            )
-            for processor, var_len_input in var_len_processor_inputs:
-                embeddings.append(processor(processed_self, var_len_input))
-            qkv = torch.cat(embeddings, dim=1)
-            attention_embedding = rsa(qkv, masks)
-            if not input_exist:
-                encoded_self = torch.cat([attention_embedding], dim=1)
-                input_exist = True
-            else:
-                encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
-
-        if not input_exist:
-            raise UnityTrainerException(
-                "The trainer was unable to process any of the provided inputs. "
-                "Make sure the trained agents have at least one sensor attached to them."
-            )
-
-        return encoded_self
-
     @staticmethod
     def trust_region_value_loss(
         values: Dict[str, torch.Tensor],
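Across both files, the net effect at each call site is that a four-argument static helper becomes a single module call. Schematically (both lines are excerpted from the hunks above; self is a NetworkBody or MultiAgentNetworkBody):

# Before: a stateless ModelUtils helper; every caller threads its own
# processors, RSA, and x_self encoder through the call.
encoded_self = ModelUtils.encode_observations(
    inputs, self.processors, self.rsa, self.x_self_encoder
)

# After: the ObservationEncoder module owns that state itself.
encoded_self = self.observation_encoder(inputs)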

0 commit comments
