Commit 9ae2c28

Fix the attention module embedding size (#5272)
* Fix the attention module embedding size
* editing the changelog
1 parent 2721989 commit 9ae2c28

File tree: 4 files changed, +32 -16 lines

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ This results in much less memory being allocated during inference with `CameraSe
 
 #### ml-agents / ml-agents-envs / gym-unity (Python)
 - Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
+- The embedding size of attention layers used when a BufferSensor is in the scene has been changed. It is now fixed to 128 units. It might be impossible to resume training from a checkpoint of a previous version. (#5272)
 
 ### Bug Fixes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
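
The changelog entry above warns that resuming from an older checkpoint may fail. A minimal sketch of why, assuming PyTorch and purely illustrative layer sizes (this is not code from the commit): once the attention width is pinned to 128 units, parameter tensors saved from a run whose attention width followed a different hidden size no longer match the new shapes.

import torch.nn as nn

# Stand-in for an attention projection whose weight shape depends on the
# embedding width; the real attention modules have more parameters, but the
# shape-mismatch failure mode when loading a checkpoint is the same.
def make_attention_projection(embedding_size: int) -> nn.Linear:
    return nn.Linear(embedding_size, embedding_size)

old = make_attention_projection(256)  # e.g. a pre-change run where the width followed h_size
new = make_attention_projection(128)  # width now fixed to 128 units

try:
    new.load_state_dict(old.state_dict())
except RuntimeError as err:
    print("resume fails:", err)  # size mismatch between [256, 256] and [128, 128] weights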

ml-agents/mlagents/trainers/tests/torch/test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def test_create_inputs(encoder_type, normalize, num_vector, num_visual):
     h_size = 128
     obs_spec = create_observation_specs_with_shapes(obs_shapes)
     encoders, embedding_sizes = ModelUtils.create_input_processors(
-        obs_spec, h_size, encoder_type, normalize
+        obs_spec, h_size, encoder_type, h_size, normalize
     )
     total_output = sum(embedding_sizes)
     vec_enc = []
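
The test now passes a fourth positional argument, the attention embedding size (here it simply reuses h_size = 128). A hedged usage sketch of the updated signature; the import paths are assumptions based on the file paths in this commit, and the observation shapes are illustrative:

from mlagents.trainers.settings import EncoderType
from mlagents.trainers.torch.utils import ModelUtils
# create_observation_specs_with_shapes comes from the test helpers; the exact
# module path below is an assumption.
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

h_size = 128
obs_spec = create_observation_specs_with_shapes([(8,), (12,)])  # two vector observations
encoders, embedding_sizes = ModelUtils.create_input_processors(
    obs_spec,
    h_size,
    EncoderType.SIMPLE,
    h_size,  # attention_embedding_size: the new fourth positional parameter
    normalize=False,
)
print(sum(embedding_sizes))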

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 21 additions & 10 deletions
@@ -32,6 +32,8 @@
 
 
 class ObservationEncoder(nn.Module):
+    ATTENTION_EMBEDDING_SIZE = 128  # The embedding size of attention is fixed
+
     def __init__(
         self,
         observation_specs: List[ObservationSpec],
@@ -45,13 +47,17 @@ def __init__(
         """
         super().__init__()
         self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
-            observation_specs, h_size, vis_encode_type, normalize=normalize
+            observation_specs,
+            h_size,
+            vis_encode_type,
+            self.ATTENTION_EMBEDDING_SIZE,
+            normalize=normalize,
         )
         self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
-            self.processors, self.embedding_sizes, h_size
+            self.processors, self.embedding_sizes, self.ATTENTION_EMBEDDING_SIZE
         )
         if self.rsa is not None:
-            total_enc_size = sum(self.embedding_sizes) + h_size
+            total_enc_size = sum(self.embedding_sizes) + self.ATTENTION_EMBEDDING_SIZE
         else:
             total_enc_size = sum(self.embedding_sizes)
         self.normalize = normalize
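
The practical effect of the ObservationEncoder hunks above is that, whenever a residual self-attention block exists, the concatenated encoding width now adds the fixed 128-unit attention output instead of h_size. A small self-contained sketch of that size computation (the numbers are illustrative, and this is not the module itself):

ATTENTION_EMBEDDING_SIZE = 128  # fixed by this commit

def total_encoding_size(embedding_sizes, has_attention: bool) -> int:
    # With attention, its fixed-width pooled output is concatenated to the
    # per-observation embeddings; without it, only the embeddings count.
    if has_attention:
        return sum(embedding_sizes) + ATTENTION_EMBEDDING_SIZE
    return sum(embedding_sizes)

# A 64-wide vector obs plus a BufferSensor obs (which reports embedding size 0):
print(total_encoding_size([64, 0], has_attention=True))   # 192 = 64 + 0 + 128
# The same vector obs alone, in a model without attention:
print(total_encoding_size([64], has_attention=False))     # 64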
@@ -247,6 +253,8 @@ def forward(
 
 
 class MultiAgentNetworkBody(torch.nn.Module):
+    ATTENTION_EMBEDDING_SIZE = 128
+
     """
     A network body that uses a self attention layer to handle state
     and action input from a potentially variable number of agents that
@@ -284,13 +292,18 @@ def __init__(
             + sum(self.action_spec.discrete_branches)
             + self.action_spec.continuous_size
         )
-        self.obs_encoder = EntityEmbedding(obs_only_ent_size, None, self.h_size)
-        self.obs_action_encoder = EntityEmbedding(q_ent_size, None, self.h_size)
 
-        self.self_attn = ResidualSelfAttention(self.h_size)
+        self.obs_encoder = EntityEmbedding(
+            obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+        )
+        self.obs_action_encoder = EntityEmbedding(
+            q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+        )
+
+        self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
 
         self.linear_encoder = LinearEncoder(
-            self.h_size,
+            self.ATTENTION_EMBEDDING_SIZE,
             network_settings.num_layers,
             self.h_size,
             kernel_gain=(0.125 / self.h_size) ** 0.5,
@@ -337,9 +350,7 @@ def _copy_and_remove_nans_from_obs(
             no_nan_obs = []
             for obs in single_agent_obs:
                 new_obs = obs.clone()
-                new_obs[
-                    attention_mask.bool()[:, i_agent], ::
-                ] = 0.0  # Remoove NaNs fast
+                new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0  # Remove NaNs fast
                 no_nan_obs.append(new_obs)
             obs_with_no_nans.append(no_nan_obs)
         return obs_with_no_nans
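
In MultiAgentNetworkBody, the entity embeddings and the ResidualSelfAttention now run at the fixed 128-unit width, so the LinearEncoder that follows them takes a 128-wide input while still producing h_size outputs. A hedged shape check of that wiring, using a plain nn.Linear as a stand-in for ml-agents' LinearEncoder (layer names and sizes here are illustrative, not the commit's code):

import torch
import torch.nn as nn

ATTENTION_EMBEDDING_SIZE = 128  # fixed by this commit
h_size = 256                    # example hidden_units from a trainer config

# Stand-in for the LinearEncoder that consumes the pooled attention output.
post_attention_encoder = nn.Linear(ATTENTION_EMBEDDING_SIZE, h_size)

pooled_attention = torch.zeros(4, ATTENTION_EMBEDDING_SIZE)  # batch of 4 pooled encodings
print(post_attention_encoder(pooled_attention).shape)  # torch.Size([4, 256])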

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 9 additions & 5 deletions
@@ -142,13 +142,15 @@ def get_encoder_for_obs(
         obs_spec: ObservationSpec,
         normalize: bool,
         h_size: int,
+        attention_embedding_size: int,
         vis_encode_type: EncoderType,
     ) -> Tuple[nn.Module, int]:
         """
         Returns the encoder and the size of the appropriate encoder.
         :param shape: Tuples that represent the observation dimension.
         :param normalize: Normalize all vector inputs.
-        :param h_size: Number of hidden units per layer.
+        :param h_size: Number of hidden units per layer excluding attention layers.
+        :param attention_embedding_size: Number of hidden units per attention layer.
         :param vis_encode_type: Type of visual encoder to use.
         """
         shape = obs_spec.shape
@@ -167,7 +169,7 @@ def get_encoder_for_obs(
                 EntityEmbedding(
                     entity_size=shape[1],
                     entity_num_max_elements=shape[0],
-                    embedding_size=h_size,
+                    embedding_size=attention_embedding_size,
                 ),
                 0,
             )
@@ -179,14 +181,16 @@ def create_input_processors(
         observation_specs: List[ObservationSpec],
         h_size: int,
         vis_encode_type: EncoderType,
+        attention_embedding_size: int,
         normalize: bool = False,
     ) -> Tuple[nn.ModuleList, List[int]]:
         """
         Creates visual and vector encoders, along with their normalizers.
         :param observation_specs: List of ObservationSpec that represent the observation dimensions.
         :param action_size: Number of additional un-normalized inputs to each vector encoder. Used for
             conditioning network on other values (e.g. actions for a Q function)
-        :param h_size: Number of hidden units per layer.
+        :param h_size: Number of hidden units per layer excluding attention layers.
+        :param attention_embedding_size: Number of hidden units per attention layer.
         :param vis_encode_type: Type of visual encoder to use.
         :param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector
             obs.
@@ -200,7 +204,7 @@ def create_input_processors(
         embedding_sizes: List[int] = []
         for obs_spec in observation_specs:
             encoder, embedding_size = ModelUtils.get_encoder_for_obs(
-                obs_spec, normalize, h_size, vis_encode_type
+                obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
            )
             encoders.append(encoder)
             embedding_sizes.append(embedding_size)
@@ -209,7 +213,7 @@ def create_input_processors(
         if x_self_size > 0:
             for enc in encoders:
                 if isinstance(enc, EntityEmbedding):
-                    enc.add_self_embedding(h_size)
+                    enc.add_self_embedding(attention_embedding_size)
         return (nn.ModuleList(encoders), embedding_sizes)
 
     @staticmethod
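
get_encoder_for_obs now threads attention_embedding_size through to EntityEmbedding (and to add_self_embedding), while visual encoders keep using h_size. A simplified sketch of just the width selection; it assumes ml-agents dispatches on observation rank, which is not fully visible in this diff, and the real method also builds the encoder objects and sizes vector observations by their own length:

def encoder_width(shape, h_size: int, attention_embedding_size: int) -> int:
    # Rank-2 (variable-length / BufferSensor) observations feed EntityEmbedding,
    # which now gets the fixed attention width; rank-3 (visual) encoders still
    # receive h_size. Rank-1 vector inputs do not take a width parameter at all.
    if len(shape) == 2:
        return attention_embedding_size
    if len(shape) == 3:
        return h_size
    raise ValueError("vector observations are sized by their own length")

print(encoder_width((20, 6), h_size=256, attention_embedding_size=128))      # 128
print(encoder_width((84, 84, 3), h_size=256, attention_embedding_size=128))  # 256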
