7
7
from mlagents .trainers .torch .action_model import ActionModel
8
8
from mlagents .trainers .torch .agent_action import AgentAction
9
9
from mlagents .trainers .torch .action_log_probs import ActionLogProbs
10
- from mlagents .trainers .settings import NetworkSettings
10
+ from mlagents .trainers .settings import NetworkSettings , EncoderType
11
11
from mlagents .trainers .torch .utils import ModelUtils
12
12
from mlagents .trainers .torch .decoders import ValueHeads
13
13
from mlagents .trainers .torch .layers import LSTM , LinearEncoder
14
14
from mlagents .trainers .torch .encoders import VectorInput
15
15
from mlagents .trainers .buffer import AgentBuffer
16
16
from mlagents .trainers .trajectory import ObsUtil
17
- from mlagents .trainers .torch .attention import EntityEmbedding , ResidualSelfAttention
17
+ from mlagents .trainers .torch .attention import (
18
+ EntityEmbedding ,
19
+ ResidualSelfAttention ,
20
+ get_zero_entities_mask ,
21
+ )
22
+ from mlagents .trainers .exception import UnityTrainerException
18
23
19
24
20
25
ActivationFunction = Callable [[torch .Tensor ], torch .Tensor ]
25
30
EPSILON = 1e-7
26
31
27
32
33
class ObservationEncoder(nn.Module):
    """
    Encodes a list of observation tensors into a single tensor.

    One input processor is created per observation spec. Processors that are
    not EntityEmbedding instances encode their input independently; inputs
    handled by EntityEmbedding processors are embedded and passed through a
    residual self-attention (RSA) module, and the result is concatenated to
    the other encodings.
    """

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        normalize: bool = False,
    ):
        """
        Builds an ObservationEncoder that can process and encode a set of observations.
        Will use an RSA if needed for variable length observations.
        :param observation_specs: Specs of the observations to encode; forwarded
            to ModelUtils.create_input_processors.
        :param h_size: Hidden size forwarded to the processor/RSA factories. It is
            also the amount added to the total encoding size when an RSA is used.
        :param vis_encode_type: Encoder type forwarded to
            ModelUtils.create_input_processors.
        :param normalize: Forwarded to ModelUtils.create_input_processors; also
            gates copy_normalization.
        """
        super().__init__()
        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            observation_specs, h_size, vis_encode_type, normalize=normalize
        )
        self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
            self.processors, self.embedding_sizes, h_size
        )
        total_enc_size = sum(self.embedding_sizes)
        if self.rsa is not None:
            # The attention output is concatenated to the other encodings in
            # forward(), so it contributes an extra h_size to the output.
            total_enc_size += h_size
        self.normalize = normalize
        self._total_enc_size = total_enc_size

    @property
    def total_enc_size(self) -> int:
        """
        Returns the total encoding size for this ObservationEncoder.
        """
        return self._total_enc_size

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Updates the normalization of every VectorInput processor.
        :param buffer: AgentBuffer from which one observation entry per
            processor is read (via ObsUtil.from_buffer).
        """
        obs = ObsUtil.from_buffer(buffer, len(self.processors))
        for vec_input, enc in zip(obs, self.processors):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))

    def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
        """
        Copies normalization from another ObservationEncoder into this one.
        Does nothing unless this encoder was created with normalize=True.
        Processors are paired by position and only pairs in which both are
        VectorInput instances are copied.
        :param other_encoder: The encoder whose processors are copied from.
        """
        if not self.normalize:
            return
        for n1, n2 in zip(self.processors, other_encoder.processors):
            if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                n1.copy_normalization(n2)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Encode observations using this encoder's processors and, if present, its RSA.
        :param inputs: List of Tensors corresponding to a set of obs, in the
            same order as self.processors.
        :return: The encodings concatenated along dim 1. When variable length
            inputs are present and an RSA exists, the attention output is
            appended after the other encodings.
        :raises UnityTrainerException: If no input could be encoded.
        """
        encodes: List[torch.Tensor] = []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if isinstance(processor, EntityEmbedding):
                # Needs the encoding of the other inputs first; handled below.
                var_len_processor_inputs.append((processor, inputs[idx]))
            else:
                # The input can be encoded without having to process other inputs
                encodes.append(processor(inputs[idx]))

        # None means "nothing encoded yet"; this replaces a separate boolean flag.
        encoded_self: Optional[torch.Tensor] = (
            torch.cat(encodes, dim=1) if encodes else None
        )

        if var_len_processor_inputs and self.rsa is not None:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask(
                [var_len_input for _, var_len_input in var_len_processor_inputs]
            )
            processed_self = (
                self.x_self_encoder(encoded_self)
                if encoded_self is not None and self.x_self_encoder is not None
                else None
            )
            embeddings: List[torch.Tensor] = [
                processor(processed_self, var_len_input)
                for processor, var_len_input in var_len_processor_inputs
            ]
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if encoded_self is None:
                # Only variable length inputs exist: the attention output is
                # the whole encoding.
                encoded_self = attention_embedding
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)

        if encoded_self is None:
            raise UnityTrainerException(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents has at least one sensor attached to them."
            )

        return encoded_self
128
+
129
+
28
130
class NetworkBody (nn .Module ):
29
131
def __init__ (
30
132
self ,
@@ -41,22 +143,13 @@ def __init__(
41
143
if network_settings .memory is not None
42
144
else 0
43
145
)
44
-
45
- self .processors , self .embedding_sizes = ModelUtils .create_input_processors (
146
+ self .observation_encoder = ObservationEncoder (
46
147
observation_specs ,
47
148
self .h_size ,
48
149
network_settings .vis_encode_type ,
49
- normalize = self .normalize ,
50
- )
51
-
52
- self .rsa , self .x_self_encoder = ModelUtils .create_residual_self_attention (
53
- self .processors , self .embedding_sizes , self .h_size
150
+ self .normalize ,
54
151
)
55
- if self .rsa is not None :
56
- total_enc_size = sum (self .embedding_sizes ) + self .h_size
57
- else :
58
- total_enc_size = sum (self .embedding_sizes )
59
-
152
+ total_enc_size = self .observation_encoder .total_enc_size
60
153
total_enc_size += encoded_act_size
61
154
self .linear_encoder = LinearEncoder (
62
155
total_enc_size , network_settings .num_layers , self .h_size
@@ -68,16 +161,10 @@ def __init__(
68
161
self .lstm = None # type: ignore
69
162
70
163
def update_normalization (self , buffer : AgentBuffer ) -> None :
71
- obs = ObsUtil .from_buffer (buffer , len (self .processors ))
72
- for vec_input , enc in zip (obs , self .processors ):
73
- if isinstance (enc , VectorInput ):
74
- enc .update_normalization (torch .as_tensor (vec_input ))
164
+ self .observation_encoder .update_normalization (buffer )
75
165
76
166
def copy_normalization (self , other_network : "NetworkBody" ) -> None :
77
- if self .normalize :
78
- for n1 , n2 in zip (self .processors , other_network .processors ):
79
- if isinstance (n1 , VectorInput ) and isinstance (n2 , VectorInput ):
80
- n1 .copy_normalization (n2 )
167
+ self .observation_encoder .copy_normalization (other_network .observation_encoder )
81
168
82
169
@property
83
170
def memory_size (self ) -> int :
@@ -90,9 +177,7 @@ def forward(
90
177
memories : Optional [torch .Tensor ] = None ,
91
178
sequence_length : int = 1 ,
92
179
) -> Tuple [torch .Tensor , torch .Tensor ]:
93
- encoded_self = ModelUtils .encode_observations (
94
- inputs , self .processors , self .rsa , self .x_self_encoder
95
- )
180
+ encoded_self = self .observation_encoder (inputs )
96
181
if actions is not None :
97
182
encoded_self = torch .cat ([encoded_self , actions ], dim = 1 )
98
183
encoding = self .linear_encoder (encoded_self )
@@ -127,27 +212,18 @@ def __init__(
127
212
if network_settings .memory is not None
128
213
else 0
129
214
)
130
- self .processors , _input_size = ModelUtils .create_input_processors (
215
+ self .action_spec = action_spec
216
+ self .observation_encoder = ObservationEncoder (
131
217
observation_specs ,
132
218
self .h_size ,
133
219
network_settings .vis_encode_type ,
134
- normalize = self .normalize ,
135
- )
136
- self .action_spec = action_spec
137
- # This RSA and input are for variable length obs, not for multi-agentt.
138
- (
139
- self .input_rsa ,
140
- self .input_x_self_encoder ,
141
- ) = ModelUtils .create_residual_self_attention (
142
- self .processors , _input_size , self .h_size
220
+ self .normalize ,
143
221
)
144
- if self .input_rsa is not None :
145
- _input_size .append (self .h_size )
146
222
147
223
# Modules for multi-agent self-attention
148
- obs_only_ent_size = sum ( _input_size )
224
+ obs_only_ent_size = self . observation_encoder . total_enc_size
149
225
q_ent_size = (
150
- sum ( _input_size )
226
+ obs_only_ent_size
151
227
+ sum (self .action_spec .discrete_branches )
152
228
+ self .action_spec .continuous_size
153
229
)
@@ -173,16 +249,10 @@ def memory_size(self) -> int:
173
249
return self .lstm .memory_size if self .use_lstm else 0
174
250
175
251
def update_normalization (self , buffer : AgentBuffer ) -> None :
176
- obs = ObsUtil .from_buffer (buffer , len (self .processors ))
177
- for vec_input , enc in zip (obs , self .processors ):
178
- if isinstance (enc , VectorInput ):
179
- enc .update_normalization (torch .as_tensor (vec_input ))
252
+ self .observation_encoder .update_normalization (buffer )
180
253
181
254
def copy_normalization (self , other_network : "MultiAgentNetworkBody" ) -> None :
182
- if self .normalize :
183
- for n1 , n2 in zip (self .processors , other_network .processors ):
184
- if isinstance (n1 , VectorInput ) and isinstance (n2 , VectorInput ):
185
- n1 .copy_normalization (n2 )
255
+ self .observation_encoder .copy_normalization (other_network .observation_encoder )
186
256
187
257
def _get_masks_from_nans (self , obs_tensors : List [torch .Tensor ]) -> torch .Tensor :
188
258
"""
@@ -243,9 +313,7 @@ def forward(
243
313
obs_attn_mask = self ._get_masks_from_nans (obs )
244
314
obs = self ._copy_and_remove_nans_from_obs (obs , obs_attn_mask )
245
315
for inputs , action in zip (obs , actions ):
246
- encoded = ModelUtils .encode_observations (
247
- inputs , self .processors , self .input_rsa , self .input_x_self_encoder
248
- )
316
+ encoded = self .observation_encoder (inputs )
249
317
cat_encodes = [
250
318
encoded ,
251
319
action .to_flat (self .action_spec .discrete_branches ),
@@ -260,9 +328,7 @@ def forward(
260
328
obs_only_attn_mask = self ._get_masks_from_nans (obs_only )
261
329
obs_only = self ._copy_and_remove_nans_from_obs (obs_only , obs_only_attn_mask )
262
330
for inputs in obs_only :
263
- encoded = ModelUtils .encode_observations (
264
- inputs , self .processors , self .input_rsa , self .input_x_self_encoder
265
- )
331
+ encoded = self .observation_encoder (inputs )
266
332
concat_encoded_obs .append (encoded )
267
333
g_inp = torch .stack (concat_encoded_obs , dim = 1 )
268
334
self_attn_masks .append (obs_only_attn_mask )
@@ -530,10 +596,10 @@ def forward(
530
596
end = 0
531
597
vis_index = 0
532
598
var_len_index = 0
533
- for i , enc in enumerate (self .network_body .processors ):
599
+ for i , enc in enumerate (self .network_body .observation_encoder . processors ):
534
600
if isinstance (enc , VectorInput ):
535
601
# This is a vec_obs
536
- vec_size = self .network_body .embedding_sizes [i ]
602
+ vec_size = self .network_body .observation_encoder . embedding_sizes [i ]
537
603
end = start + vec_size
538
604
inputs .append (concatenated_vec_obs [:, start :end ])
539
605
start = end
0 commit comments