
Adding a fully connected visual encoder for super small visual input + tests #5351


Merged (5 commits) on May 11, 2021
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a fully connected visual encoder for environments with very small image inputs. (#5351)
### Bug Fixes


2 changes: 1 addition & 1 deletion docs/Training-Configuration-File.md
@@ -42,7 +42,7 @@ choice of the trainer (which we review on subsequent sections).
| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Corresponds to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsson et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsson et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. `fully_connected` uses a single fully connected dense layer as the encoder and should be reserved for very small inputs (see the sketch after this table). |
| `network_settings -> conditioning_type` | (default = `hyper`) Conditioning type for the policy using goal observations. <br><br> `none` treats the goal observations as regular observations, `hyper` (default) uses a HyperNetwork with goal observations as input to generate some of the weights of the policy. Note that when using `hyper` the number of parameters of the network increases greatly. Therefore, it is recommended to reduce the number of `hidden_units` when using this `conditioning_type`. |
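For readers wiring this up, here is a minimal sketch of selecting the new encoder programmatically through `NetworkSettings`, the Python-side equivalent of putting `vis_encode_type: fully_connected` under `network_settings` in a trainer config. The exact `NetworkSettings` fields are assumed from the trainer settings module rather than shown in this diff, and the other keyword arguments are just the documented defaults:

```python
from mlagents.trainers.settings import EncoderType, NetworkSettings

# Assumed: NetworkSettings exposes keyword arguments matching the options
# documented in the table above.
settings = NetworkSettings(
    hidden_units=128,                             # default
    num_layers=2,                                 # default
    vis_encode_type=EncoderType.FULLY_CONNECTED,  # the new encoder
)
print(settings.vis_encode_type.value)  # "fully_connected"
```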


1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -81,6 +81,7 @@ def as_dict(self):


class EncoderType(Enum):
FULLY_CONNECTED = "fully_connected"
MATCH3 = "match3"
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"
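Since `EncoderType` is a string-valued `Enum`, the `fully_connected` string from a trainer config maps directly onto the new member. A quick sketch of that round-trip, importing from the module this hunk touches:

```python
from mlagents.trainers.settings import EncoderType

# The YAML value resolves straight to the new enum member, and back again.
assert EncoderType("fully_connected") is EncoderType.FULLY_CONNECTED
assert EncoderType.FULLY_CONNECTED.value == "fully_connected"
```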
42 changes: 41 additions & 1 deletion ml-agents/mlagents/trainers/tests/torch/test_encoders.py
@@ -5,6 +5,8 @@
from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
@@ -73,7 +75,14 @@ def test_vector_encoder(mock_normalizer):

@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
"vis_class",
[
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
],
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128
@@ -82,3 +91,34 @@ def test_visual_encoder(vis_class, image_size):
sample_input = torch.ones((1, image_size[0], image_size[1], image_size[2]))
encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)
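Worth noting about the shape test above: `FullyConnectedVisualEncoder` is exercised up to `(256, 256, 5)`, where its single dense layer is enormous. A back-of-the-envelope sketch (bias terms ignored) of why the docs reserve this encoder for very small inputs:

```python
# Weight count of the dense encoder grows as height * width * channels * output_size.
def fc_encoder_weights(height: int, width: int, channels: int, output_size: int) -> int:
    return height * width * channels * output_size

print(fc_encoder_weights(5, 5, 3, 128))      # 9,600 weights: fine for a tiny board view
print(fc_encoder_weights(256, 256, 5, 128))  # 41,943,040 weights: far larger than the CNNs
```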


@pytest.mark.parametrize(
"vis_class, size",
[
(SimpleVisualEncoder, 36),
(ResNetVisualEncoder, 36),
(NatureVisualEncoder, 36),
(SmallVisualEncoder, 10),
(FullyConnectedVisualEncoder, 36),
],
)
def test_visual_encoder_trains(vis_class, size):
torch.manual_seed(0)
image_size = (size, size, 1)
batch = 100

inputs = torch.cat(
[torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
)
target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)

for _ in range(15):
prediction = enc(inputs)[:, 0]
loss = torch.mean((target - prediction) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
assert loss.item() < 0.05
24 changes: 24 additions & 0 deletions ml-agents/mlagents/trainers/torch/encoders.py
@@ -111,6 +111,30 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
self.normalizer.update(inputs)


class FullyConnectedVisualEncoder(nn.Module):
def __init__(
self, height: int, width: int, initial_channels: int, output_size: int
):
super().__init__()
self.output_size = output_size
self.input_size = height * width * initial_channels
self.dense = nn.Sequential(
linear_layer(
self.input_size,
self.output_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.41, # Use ReLU gain
),
nn.LeakyReLU(),
**Contributor Author:** We could use another activation here, but I think ReLU is enough.
)

def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
if not exporting_to_onnx.is_exporting():
visual_obs = visual_obs.permute([0, 3, 1, 2])
hidden = visual_obs.reshape(-1, self.input_size)
return self.dense(hidden)
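A minimal usage sketch of the new encoder (assuming the package paths in this diff): outside of ONNX export, `forward` expects NHWC observations, permutes them to NCHW, then flattens before the dense layer.

```python
import torch
from mlagents.trainers.torch.encoders import FullyConnectedVisualEncoder

enc = FullyConnectedVisualEncoder(height=5, width=5, initial_channels=3, output_size=128)

# During training, visual observations arrive as (batch, height, width, channels);
# forward() permutes to NCHW and flattens to (batch, 5 * 5 * 3) before the dense layer.
obs = torch.rand(4, 5, 5, 3)
encoding = enc(obs)
assert encoding.shape == (4, 128)
```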


class SmallVisualEncoder(nn.Module):
"""
CNN architecture used by King in their Candy Crush predictor
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/torch/utils.py
@@ -8,6 +8,7 @@
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
@@ -20,6 +21,7 @@ class ModelUtils:
# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.
MIN_RESOLUTION_FOR_ENCODER = {
EncoderType.FULLY_CONNECTED: 1,
EncoderType.MATCH3: 5,
EncoderType.SIMPLE: 20,
EncoderType.NATURE_CNN: 36,
@@ -123,6 +125,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
EncoderType.MATCH3: SmallVisualEncoder,
EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
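Putting the two new registrations together, a short sketch of how a trainer would resolve the encoder class and its minimum resolution (names as they appear in this diff; `get_encoder_for_type` is assumed callable on the class as the surrounding code suggests):

```python
from mlagents.trainers.settings import EncoderType
from mlagents.trainers.torch.utils import ModelUtils

enc_class = ModelUtils.get_encoder_for_type(EncoderType.FULLY_CONNECTED)
print(enc_class.__name__)  # FullyConnectedVisualEncoder

# A minimum side of 1 means any visual observation size passes the resolution check.
print(ModelUtils.MIN_RESOLUTION_FOR_ENCODER[EncoderType.FULLY_CONNECTED])  # 1
```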
