- import torch
- from torch import nn
+ from typing import Tuple, Optional
+
+ from mlagents.trainers.exception import UnityTrainerException


- class VectorEncoder(nn.Module):
-     def __init__(self, input_size, hidden_size, num_layers, **kwargs):
-         super().__init__(**kwargs)
-         self.layers = [nn.Linear(input_size, hidden_size)]
-         for _ in range(num_layers - 1):
-             self.layers.append(nn.Linear(hidden_size, hidden_size))
-             self.layers.append(nn.ReLU())
-         self.seq_layers = nn.Sequential(*self.layers)
-
-     def forward(self, inputs):
-         return self.seq_layers(inputs)
+ import torch
+ from torch import nn


class Normalizer(nn.Module):
-     def __init__(self, vec_obs_size, **kwargs):
-         super().__init__(**kwargs)
+     def __init__(self, vec_obs_size: int):
+         super().__init__()
        self.normalization_steps = torch.tensor(1)
        self.running_mean = torch.zeros(vec_obs_size)
        self.running_variance = torch.ones(vec_obs_size)

-     def forward(self, inputs):
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        normalized_state = torch.clamp(
            (inputs - self.running_mean)
            / torch.sqrt(self.running_variance / self.normalization_steps),
@@ -31,7 +22,7 @@ def forward(self, inputs):
        )
        return normalized_state

-     def update(self, vector_input):
+     def update(self, vector_input: torch.Tensor) -> None:
        steps_increment = vector_input.size()[0]
        total_new_steps = self.normalization_steps + steps_increment

@@ -66,14 +57,96 @@ def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    return h, w


- def pool_out_shape(h_w, kernel_size):
+ def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
    height = (h_w[0] - kernel_size) // 2 + 1
    width = (h_w[1] - kernel_size) // 2 + 1
    return height, width


+ class VectorEncoder(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         num_layers: int,
+         normalize: bool = False,
+     ):
+         self.normalizer: Optional[Normalizer] = None
+         super().__init__()
+         self.layers = [nn.Linear(input_size, hidden_size)]
+         if normalize:
+             self.normalizer = Normalizer(input_size)
+
+         for _ in range(num_layers - 1):
+             self.layers.append(nn.Linear(hidden_size, hidden_size))
+             self.layers.append(nn.ReLU())
+         self.seq_layers = nn.Sequential(*self.layers)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         if self.normalizer is not None:
+             inputs = self.normalizer(inputs)
+         return self.seq_layers(inputs)
+
+     def copy_normalization(self, other_encoder: "VectorEncoder") -> None:
+         if self.normalizer is not None and other_encoder.normalizer is not None:
+             self.normalizer.copy_from(other_encoder.normalizer)
+
+     def update_normalization(self, inputs: torch.Tensor) -> None:
+         if self.normalizer is not None:
+             self.normalizer.update(inputs)
+
+
+ class VectorAndUnnormalizedInputEncoder(VectorEncoder):
+     """
+     Encoder for concatenated vector input (can be normalized) and unnormalized vector input.
+     This is used for passing inputs to the network that should not be normalized, such as
+     actions in the case of a Q function or task parameterizations. It will result in an encoder with
+     this structure:
+      ____________       ____________       ____________
+     | Vector     |     | Normalize  |     | Fully      |
+     |            | --> |            | --> | Connected  |     ___________
+     |____________|     |____________|     |            |    | Output    |
+      ____________                         |            | -->|           |
+     |Unnormalized|                        |            |    |___________|
+     | Input      | ---------------------> |            |
+     |____________|                        |____________|
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         unnormalized_input_size: int,
+         num_layers: int,
+         normalize: bool = False,
+     ):
+         super().__init__(
+             input_size + unnormalized_input_size,
+             hidden_size,
+             num_layers,
+             normalize=False,
+         )
+         if normalize:
+             self.normalizer = Normalizer(input_size)
+         else:
+             self.normalizer = None
+
+     def forward(  # pylint: disable=W0221
+         self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         if unnormalized_inputs is None:
+             raise UnityTrainerException(
+                 "Attempted to call a VectorAndUnnormalizedInputEncoder without an unnormalized input."
+             )  # Fix mypy errors about method parameters.
+         if self.normalizer is not None:
+             inputs = self.normalizer(inputs)
+         return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1))
+
+
class SimpleVisualEncoder(nn.Module):
-     def __init__(self, height, width, initial_channels, output_size):
+     def __init__(
+         self, height: int, width: int, initial_channels: int, output_size: int
+     ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
@@ -84,7 +157,7 @@ def __init__(self, height, width, initial_channels, output_size):
        self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2])
        self.dense = nn.Linear(self.final_flat, self.h_size)

-     def forward(self, visual_obs):
+     def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        conv_1 = torch.relu(self.conv1(visual_obs))
        conv_2 = torch.relu(self.conv2(conv_1))
        # hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
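For context, here is a minimal usage sketch of the two vector encoders added in this diff. It is not part of the commit; the batch sizes, observation/action sizes, and variable names are illustrative assumptions, and it assumes the classes defined above are importable or in scope.

```python
# Illustrative sketch only (not part of the commit).
# Assumes VectorEncoder and VectorAndUnnormalizedInputEncoder from the diff above are in scope.
import torch

obs = torch.rand(32, 8)       # batch of 32 vector observations of size 8 (assumed shapes)
actions = torch.rand(32, 2)   # unnormalized inputs, e.g. continuous actions

vec_enc = VectorEncoder(input_size=8, hidden_size=64, num_layers=2, normalize=True)
vec_enc.update_normalization(obs)   # update the running mean/variance of the Normalizer
hidden = vec_enc(obs)               # normalized, then passed through the linear stack -> (32, 64)

q_enc = VectorAndUnnormalizedInputEncoder(
    input_size=8,
    hidden_size=64,
    unnormalized_input_size=2,
    num_layers=2,
    normalize=True,
)
q_hidden = q_enc(obs, actions)      # only `obs` is normalized before concatenation -> (32, 64)
```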