Commit 38c3dd1

Author: Ervin T
[refactor] Refactor normalizers and encoders (#4275)
* Refactor normalizers and encoders
* Unify Critic and ValueNetwork
* Rename ActionVectorEncoder
* Update docstring of create_encoders
* Add docstring to UnnormalizedInputEncoder
1 parent 4214ec8 commit 38c3dd1

File tree

4 files changed: +176 -164 lines changed

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 22 additions & 8 deletions
```diff
@@ -8,7 +8,7 @@
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.settings import NetworkSettings
-from mlagents.trainers.torch.networks import Critic, QNetwork
+from mlagents.trainers.torch.networks import ValueNetwork
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents_envs.timers import timed
@@ -31,11 +31,25 @@ def __init__(
         act_size: List[int],
     ):
         super().__init__()
-        self.q1_network = QNetwork(
-            stream_names, observation_shapes, network_settings, act_type, act_size
+        if act_type == ActionType.CONTINUOUS:
+            num_value_outs = 1
+            num_action_ins = sum(act_size)
+        else:
+            num_value_outs = sum(act_size)
+            num_action_ins = 0
+        self.q1_network = ValueNetwork(
+            stream_names,
+            observation_shapes,
+            network_settings,
+            num_action_ins,
+            num_value_outs,
         )
-        self.q2_network = QNetwork(
-            stream_names, observation_shapes, network_settings, act_type, act_size
+        self.q2_network = ValueNetwork(
+            stream_names,
+            observation_shapes,
+            network_settings,
+            num_action_ins,
+            num_value_outs,
         )

     def forward(
```
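
The act_type branch above is what lets the unified ValueNetwork stand in for the old QNetwork. A minimal sketch of the sizing rule, with illustrative act_size values that are not from the commit:

```python
# Sizing rule for ValueNetwork used as a Q-network (example values only).

# Continuous actions: the action vector is fed into the network as an extra
# input, and a single Q-value is produced per reward stream.
act_size = [3]                   # hypothetical 3-dimensional continuous action
num_action_ins = sum(act_size)   # 3 extra inputs, concatenated with the encoding
num_value_outs = 1               # one Q(s, a) head

# Discrete actions: no action input; the network emits one Q-value per action
# across all branches, so Q(s, .) is read off in a single forward pass.
act_size = [3, 2]                # hypothetical branches with 3 and 2 actions
num_action_ins = 0
num_value_outs = sum(act_size)   # 5 Q-values
```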
```diff
@@ -86,7 +100,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
             self.policy.behavior_spec.action_type,
             self.act_size,
         )
-        self.target_network = Critic(
+        self.target_network = ValueNetwork(
             self.stream_names,
             self.policy.behavior_spec.observation_shapes,
             policy_network_settings,
@@ -370,10 +384,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             next_vis_obs.append(next_vis_ob)

         # Copy normalizers from policy
-        self.value_network.q1_network.copy_normalization(
+        self.value_network.q1_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        self.value_network.q2_network.copy_normalization(
+        self.value_network.q2_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
         self.target_network.network_body.copy_normalization(
```
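
These copies now reach through network_body because normalization state lives inside the encoders (see encoders.py below), whose copy_normalization delegates to a Normalizer.copy_from that this diff does not show. A plausible sketch of that method, assuming it simply mirrors the running statistics (the body here is a guess, not the committed code):

```python
import torch
from torch import nn


class Normalizer(nn.Module):
    """Stand-in with the fields shown in encoders.py below."""

    def __init__(self, vec_obs_size: int):
        super().__init__()
        self.normalization_steps = torch.tensor(1)
        self.running_mean = torch.zeros(vec_obs_size)
        self.running_variance = torch.ones(vec_obs_size)

    def copy_from(self, other: "Normalizer") -> None:
        # Hypothetical body: adopt the other normalizer's running statistics
        # so both networks normalize observations identically.
        self.normalization_steps = other.normalization_steps.clone()
        self.running_mean = other.running_mean.clone()
        self.running_variance = other.running_variance.clone()
```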

ml-agents/mlagents/trainers/torch/encoders.py

Lines changed: 93 additions & 20 deletions
```diff
@@ -1,28 +1,19 @@
-import torch
-from torch import nn
+from typing import Tuple, Optional

+from mlagents.trainers.exception import UnityTrainerException

-class VectorEncoder(nn.Module):
-    def __init__(self, input_size, hidden_size, num_layers, **kwargs):
-        super().__init__(**kwargs)
-        self.layers = [nn.Linear(input_size, hidden_size)]
-        for _ in range(num_layers - 1):
-            self.layers.append(nn.Linear(hidden_size, hidden_size))
-            self.layers.append(nn.ReLU())
-        self.seq_layers = nn.Sequential(*self.layers)
-
-    def forward(self, inputs):
-        return self.seq_layers(inputs)
+import torch
+from torch import nn


 class Normalizer(nn.Module):
-    def __init__(self, vec_obs_size, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, vec_obs_size: int):
+        super().__init__()
         self.normalization_steps = torch.tensor(1)
         self.running_mean = torch.zeros(vec_obs_size)
         self.running_variance = torch.ones(vec_obs_size)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         normalized_state = torch.clamp(
             (inputs - self.running_mean)
             / torch.sqrt(self.running_variance / self.normalization_steps),
@@ -31,7 +22,7 @@ def forward(self, inputs):
         )
         return normalized_state

-    def update(self, vector_input):
+    def update(self, vector_input: torch.Tensor) -> None:
         steps_increment = vector_input.size()[0]
         total_new_steps = self.normalization_steps + steps_increment

```
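The hunk truncates update after its first two lines; given the running-mean and running-variance fields above, the rest is presumably the usual incremental scheme. A sketch of how such an update completes, written as a free function over explicit state (the continuation is an assumption, not the committed body):

```python
from typing import Tuple

import torch


def update_running_stats(
    running_mean: torch.Tensor,
    running_variance: torch.Tensor,
    normalization_steps: torch.Tensor,
    vector_input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # First two steps mirror the hunk above.
    steps_increment = vector_input.size()[0]
    total_new_steps = normalization_steps + steps_increment

    # Assumed continuation: standard incremental mean/variance update over the batch.
    input_to_old_mean = vector_input - running_mean
    new_mean = running_mean + (input_to_old_mean / total_new_steps).sum(0)
    input_to_new_mean = vector_input - new_mean
    new_variance = running_variance + (input_to_new_mean * input_to_old_mean).sum(0)
    return new_mean, new_variance, total_new_steps
```
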
```diff
@@ -66,14 +57,96 @@ def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
     return h, w


-def pool_out_shape(h_w, kernel_size):
+def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
     height = (h_w[0] - kernel_size) // 2 + 1
     width = (h_w[1] - kernel_size) // 2 + 1
     return height, width


+class VectorEncoder(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        normalize: bool = False,
+    ):
+        self.normalizer: Optional[Normalizer] = None
+        super().__init__()
+        self.layers = [nn.Linear(input_size, hidden_size)]
+        if normalize:
+            self.normalizer = Normalizer(input_size)
+
+        for _ in range(num_layers - 1):
+            self.layers.append(nn.Linear(hidden_size, hidden_size))
+            self.layers.append(nn.ReLU())
+        self.seq_layers = nn.Sequential(*self.layers)
+
+    def forward(self, inputs: torch.Tensor) -> None:
+        if self.normalizer is not None:
+            inputs = self.normalizer(inputs)
+        return self.seq_layers(inputs)
+
+    def copy_normalization(self, other_encoder: "VectorEncoder") -> None:
+        if self.normalizer is not None and other_encoder.normalizer is not None:
+            self.normalizer.copy_from(other_encoder.normalizer)
+
+    def update_normalization(self, inputs: torch.Tensor) -> None:
+        if self.normalizer is not None:
+            self.normalizer.update(inputs)
+
+
+class VectorAndUnnormalizedInputEncoder(VectorEncoder):
+    """
+    Encoder for concatenated vector input (can be normalized) and unnormalized vector input.
+    This is used for passing inputs to the network that should not be normalized, such as
+    actions in the case of a Q function or task parameterizations. It will result in an encoder with
+    this structure:
+     ____________       ____________       ____________
+    | Vector     |     | Normalize  |     | Fully      |
+    |            | --> |            | --> | Connected  |      ___________
+    |____________|     |____________|     |            |     | Output    |
+     ____________                         |            | --> |           |
+    |Unnormalized|                        |            |     |___________|
+    | Input      | ---------------------> |            |
+    |____________|                        |____________|
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        unnormalized_input_size: int,
+        num_layers: int,
+        normalize: bool = False,
+    ):
+        super().__init__(
+            input_size + unnormalized_input_size,
+            hidden_size,
+            num_layers,
+            normalize=False,
+        )
+        if normalize:
+            self.normalizer = Normalizer(input_size)
+        else:
+            self.normalizer = None
+
+    def forward(  # pylint: disable=W0221
+        self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None
+    ) -> None:
+        if unnormalized_inputs is None:
+            raise UnityTrainerException(
+                "Attempted to call an VectorAndUnnormalizedInputEncoder without an unnormalized input."
+            )  # Fix mypy errors about method parameters.
+        if self.normalizer is not None:
+            inputs = self.normalizer(inputs)
+        return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1))
+
+
 class SimpleVisualEncoder(nn.Module):
-    def __init__(self, height, width, initial_channels, output_size):
+    def __init__(
+        self, height: int, width: int, initial_channels: int, output_size: int
+    ):
         super().__init__()
         self.h_size = output_size
         conv_1_hw = conv_output_shape((height, width), 8, 4)
@@ -84,7 +157,7 @@ def __init__(self, height, width, initial_channels, output_size):
         self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2])
         self.dense = nn.Linear(self.final_flat, self.h_size)

-    def forward(self, visual_obs):
+    def forward(self, visual_obs: torch.Tensor) -> None:
         conv_1 = torch.relu(self.conv1(visual_obs))
         conv_2 = torch.relu(self.conv2(conv_1))
         # hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
```
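
To make the concatenation concrete, here is a usage sketch of the new encoder as the input layer of a Q-function. All sizes are illustrative, and the classes are assumed importable from the module changed above:

```python
import torch

from mlagents.trainers.torch.encoders import VectorAndUnnormalizedInputEncoder

# Illustrative sizes (not from the commit): 8-dim observation, 2-dim action.
encoder = VectorAndUnnormalizedInputEncoder(
    input_size=8,
    hidden_size=64,
    unnormalized_input_size=2,
    num_layers=2,
    normalize=True,
)

obs = torch.randn(32, 8)        # observations pass through the Normalizer
actions = torch.randn(32, 2)    # actions bypass normalization entirely
hidden = encoder(obs, actions)  # (32, 64) encoding, ready for a Q-value head
```

Note that the subclass calls the base constructor with normalize=False and then builds its own Normalizer over input_size alone, so only the observation half of the concatenated input is ever normalized.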

0 commit comments