Commit bbdcab1

Adding a fully connected visual encoder for super small visual input + tests (#5351)
* initial commit for a fully connected visual encoder
* adding a test
* addressing comments
* Fixing error with minimal size of fully connected network
* adding documentation and changelog
1 parent 326a277 commit bbdcab1

6 files changed: +71 -2 lines changed


com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to
 ### Minor Changes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
+- Added a fully connected visual encoder for environments with very small image inputs. (#5351)
 ### Bug Fixes

docs/Training-Configuration-File.md

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ choice of the trainer (which we review on subsequent sections).
 | `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
 | `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
 | `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
-| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
+| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. `fully_connected` uses a single fully connected dense layer as encoder and should be reserved for very small inputs. |
 | `network_settings -> conditioning_type` | (default = `hyper`) Conditioning type for the policy using goal observations. <br><br> `none` treats the goal observations as regular observations, `hyper` (default) uses a HyperNetwork with goal observations as input to generate some of the weights of the policy. Note that when using `hyper` the number of parameters of the network increases greatly. Therefore, it is recommended to reduce the number of `hidden_units` when using this `conditioning_type`

ml-agents/mlagents/trainers/settings.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def as_dict(self):


 class EncoderType(Enum):
+    FULLY_CONNECTED = "fully_connected"
     MATCH3 = "match3"
     SIMPLE = "simple"
     NATURE_CNN = "nature_cnn"
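
The new enum member is what allows the string `fully_connected` from a trainer configuration to resolve to an encoder type. A minimal sketch of that lookup follows. The enum is a local stand-in that copies only the members visible in this hunk (it is not imported from `mlagents`), and `parse_encoder_type` is a hypothetical helper; how the real settings code performs the conversion is not shown in this diff.

```python
from enum import Enum


class EncoderType(Enum):
    # Local stand-in: only the members visible in the settings.py hunk.
    FULLY_CONNECTED = "fully_connected"
    MATCH3 = "match3"
    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"


def parse_encoder_type(name: str) -> EncoderType:
    # Hypothetical helper. Looking an Enum up by value raises ValueError
    # when the string does not match any member.
    return EncoderType(name)


if __name__ == "__main__":
    print(parse_encoder_type("fully_connected"))
```

An unknown string fails loudly with `ValueError`, which is one reason a value-backed `Enum` is a convenient way to validate a configuration field.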

ml-agents/mlagents/trainers/tests/torch/test_encoders.py

Lines changed: 41 additions & 1 deletion
@@ -5,6 +5,8 @@
 from mlagents.trainers.torch.encoders import (
     VectorInput,
     Normalizer,
+    SmallVisualEncoder,
+    FullyConnectedVisualEncoder,
     SimpleVisualEncoder,
     ResNetVisualEncoder,
     NatureVisualEncoder,
@@ -73,7 +75,14 @@ def test_vector_encoder(mock_normalizer):

 @pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
 @pytest.mark.parametrize(
-    "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
+    "vis_class",
+    [
+        SimpleVisualEncoder,
+        ResNetVisualEncoder,
+        NatureVisualEncoder,
+        SmallVisualEncoder,
+        FullyConnectedVisualEncoder,
+    ],
 )
 def test_visual_encoder(vis_class, image_size):
     num_outputs = 128
@@ -82,3 +91,34 @@ def test_visual_encoder(vis_class, image_size):
     sample_input = torch.ones((1, image_size[0], image_size[1], image_size[2]))
     encoding = enc(sample_input)
     assert encoding.shape == (1, num_outputs)
+
+
+@pytest.mark.parametrize(
+    "vis_class, size",
+    [
+        (SimpleVisualEncoder, 36),
+        (ResNetVisualEncoder, 36),
+        (NatureVisualEncoder, 36),
+        (SmallVisualEncoder, 10),
+        (FullyConnectedVisualEncoder, 36),
+    ],
+)
+def test_visual_encoder_trains(vis_class, size):
+    torch.manual_seed(0)
+    image_size = (size, size, 1)
+    batch = 100
+
+    inputs = torch.cat(
+        [torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
+    )
+    target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
+    enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
+    optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)
+
+    for _ in range(15):
+        prediction = enc(inputs)[:, 0]
+        loss = torch.mean((target - prediction) ** 2)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    assert loss.item() < 0.05
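
The idea behind `test_visual_encoder_trains` (separate all-zero images from all-one images with a mean squared error loss) can be illustrated without PyTorch. The sketch below is not the test itself. It assumes a single linear unit followed by a leaky ReLU, replaces Adam with plain gradient descent, and uses a hypothetical 4-pixel input with a hand-picked learning rate and step count.

```python
def leaky_relu(x, slope=0.01):
    return x if x >= 0.0 else slope * x


def leaky_relu_grad(x, slope=0.01):
    return 1.0 if x >= 0.0 else slope


def train(num_pixels=4, steps=200, lr=0.1):
    """Fit one linear unit + leaky ReLU to map zeros -> 0 and ones -> 1."""
    weights = [0.01] * num_pixels  # small deterministic init (assumption)
    bias = 0.0
    samples = [([0.0] * num_pixels, 0.0), ([1.0] * num_pixels, 1.0)]
    losses = []
    for _ in range(steps):
        grad_w = [0.0] * num_pixels
        grad_b = 0.0
        loss = 0.0
        for pixels, target in samples:
            pre = sum(w * p for w, p in zip(weights, pixels)) + bias
            error = leaky_relu(pre) - target
            loss += error ** 2 / len(samples)  # mean squared error
            # Chain rule: d(loss)/d(pre) for this sample.
            d_pre = 2.0 * error * leaky_relu_grad(pre) / len(samples)
            for i, p in enumerate(pixels):
                grad_w[i] += d_pre * p
            grad_b += d_pre
        losses.append(loss)
        # Plain gradient descent step (the real test uses Adam).
        weights = [w - lr * g for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b
    return losses


if __name__ == "__main__":
    history = train()
    print(f"first loss: {history[0]:.4f}, last loss: {history[-1]:.6f}")
```

The real test applies the same pass criterion, a final loss below `0.05`, to each encoder class after 15 Adam steps.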

ml-agents/mlagents/trainers/torch/encoders.py

Lines changed: 24 additions & 0 deletions
@@ -111,6 +111,30 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
         self.normalizer.update(inputs)


+class FullyConnectedVisualEncoder(nn.Module):
+    def __init__(
+        self, height: int, width: int, initial_channels: int, output_size: int
+    ):
+        super().__init__()
+        self.output_size = output_size
+        self.input_size = height * width * initial_channels
+        self.dense = nn.Sequential(
+            linear_layer(
+                self.input_size,
+                self.output_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=1.41,  # Use ReLU gain
+            ),
+            nn.LeakyReLU(),
+        )
+
+    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
+        if not exporting_to_onnx.is_exporting():
+            visual_obs = visual_obs.permute([0, 3, 1, 2])
+        hidden = visual_obs.reshape(-1, self.input_size)
+        return self.dense(hidden)
+
+
 class SmallVisualEncoder(nn.Module):
     """
     CNN architecture used by King in their Candy Crush predictor
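
Because the encoder flattens the whole image into `height * width * initial_channels` inputs for a single dense layer, its parameter count grows linearly with the number of pixels. The arithmetic below is a rough sketch of why the documentation reserves this encoder for very small inputs. It assumes the dense layer carries a bias term, as `torch.nn.Linear` does by default, and `fully_connected_encoder_params` is a hypothetical helper, not part of `mlagents`.

```python
def fully_connected_encoder_params(
    height: int, width: int, channels: int, output_size: int
) -> int:
    # One weight per (input, output) pair, plus one bias per output (assumed).
    input_size = height * width * channels
    return input_size * output_size + output_size


if __name__ == "__main__":
    # Image sizes taken from the test parametrization in this commit,
    # plus a 5x5 single-channel board as a "very small input" example.
    for shape in [(5, 5, 1), (36, 36, 3), (84, 84, 4), (256, 256, 5)]:
        count = fully_connected_encoder_params(*shape, output_size=128)
        print(shape, count)
```

With 128 outputs, a 5x5x1 input needs 3,328 parameters, a 36x36x3 input needs 497,792, and a 256x256x5 input needs 41,943,168.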

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
     ResNetVisualEncoder,
     NatureVisualEncoder,
     SmallVisualEncoder,
+    FullyConnectedVisualEncoder,
     VectorInput,
 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
@@ -20,6 +21,7 @@ class ModelUtils:
     # Minimum supported side for each encoder type. If refactoring an encoder, please
     # adjust these also.
     MIN_RESOLUTION_FOR_ENCODER = {
+        EncoderType.FULLY_CONNECTED: 1,
         EncoderType.MATCH3: 5,
         EncoderType.SIMPLE: 20,
         EncoderType.NATURE_CNN: 36,
@@ -123,6 +125,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
             EncoderType.NATURE_CNN: NatureVisualEncoder,
             EncoderType.RESNET: ResNetVisualEncoder,
             EncoderType.MATCH3: SmallVisualEncoder,
+            EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
         }
         return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
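
A minimum side length of `1` means the fully connected encoder accepts any non-empty image, whereas the convolutional encoders need enough pixels for their kernels. The sketch below shows one way such a table could be used to validate an observation shape. It copies only the values visible in this hunk, keys them by plain strings instead of `EncoderType`, and `check_resolution` is a hypothetical helper; the actual validation code in `ModelUtils` is not shown in this diff.

```python
# Values copied from the hunk above; keyed by string for a self-contained example.
MIN_RESOLUTION_FOR_ENCODER = {
    "fully_connected": 1,
    "match3": 5,
    "simple": 20,
    "nature_cnn": 36,
}


def check_resolution(encoder_type: str, height: int, width: int) -> None:
    # Hypothetical helper: reject images smaller than the encoder's minimum side.
    min_res = MIN_RESOLUTION_FOR_ENCODER[encoder_type]
    if height < min_res or width < min_res:
        raise ValueError(
            f"Visual observation of {height}x{width} is too small for "
            f"'{encoder_type}' (minimum side: {min_res})."
        )


if __name__ == "__main__":
    check_resolution("fully_connected", 3, 3)  # accepted
    try:
        check_resolution("match3", 3, 3)
    except ValueError as err:
        print(err)
```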
