Skip to content

Commit 47625be

Browse files
author
Chris Elion
authored
check min size for visual encoders (#3112)
* check min size for visual encoders
* friendlier exception
* fix typo
1 parent ee81d99 commit 47625be

File tree

2 files changed

+81
-13
lines changed

2 files changed

+81
-13
lines changed

ml-agents/mlagents/trainers/models.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
logger = logging.getLogger("mlagents.trainers")
1212

1313
ActivationFunction = Callable[[tf.Tensor], tf.Tensor]
14+
EncoderFunction = Callable[
15+
[tf.Tensor, int, ActivationFunction, int, str, bool], tf.Tensor
16+
]
1417

1518
EPSILON = 1e-7
1619

@@ -26,9 +29,17 @@ class LearningRateSchedule(Enum):
2629
LINEAR = "linear"
2730

2831

29-
class LearningModel(object):
32+
class LearningModel:
3033
_version_number_ = 2
3134

35+
# Minimum supported side for each encoder type. If refactoring an encoder, please
36+
# adjust these also.
37+
MIN_RESOLUTION_FOR_ENCODER = {
38+
EncoderType.SIMPLE: 20,
39+
EncoderType.NATURE_CNN: 36,
40+
EncoderType.RESNET: 15,
41+
}
42+
3243
def __init__(
3344
self, m_size, normalize, use_recurrent, brain, seed, stream_names=None
3445
):
@@ -427,6 +438,17 @@ def create_resnet_visual_observation_encoder(
427438
)
428439
return hidden_flat
429440

441+
@staticmethod
442+
def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
443+
ENCODER_FUNCTION_BY_TYPE = {
444+
EncoderType.SIMPLE: LearningModel.create_visual_observation_encoder,
445+
EncoderType.NATURE_CNN: LearningModel.create_nature_cnn_visual_observation_encoder,
446+
EncoderType.RESNET: LearningModel.create_resnet_visual_observation_encoder,
447+
}
448+
return ENCODER_FUNCTION_BY_TYPE.get(
449+
encoder_type, LearningModel.create_visual_observation_encoder
450+
)
451+
430452
@staticmethod
431453
def create_discrete_action_masking_layer(all_logits, action_masks, action_size):
432454
"""
@@ -474,6 +496,17 @@ def create_discrete_action_masking_layer(all_logits, action_masks, action_size):
474496
),
475497
)
476498

499+
@staticmethod
500+
def _check_resolution_for_encoder(
501+
camera_res: CameraResolution, vis_encoder_type: EncoderType
502+
) -> None:
503+
min_res = LearningModel.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
504+
if camera_res.height < min_res or camera_res.width < min_res:
505+
raise UnityTrainerException(
506+
f"Visual observation resolution ({camera_res.width}x{camera_res.height}) is too small for"
507+
f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
508+
)
509+
477510
def create_observation_streams(
478511
self,
479512
num_streams: int,
@@ -496,23 +529,20 @@ def create_observation_streams(
496529

497530
self.visual_in = []
498531
for i in range(brain.number_visual_observations):
532+
LearningModel._check_resolution_for_encoder(
533+
brain.camera_resolutions[i], vis_encode_type
534+
)
499535
visual_input = self.create_visual_input(
500536
brain.camera_resolutions[i], name="visual_observation_" + str(i)
501537
)
502538
self.visual_in.append(visual_input)
503539
vector_observation_input = self.create_vector_input()
504540

505-
# Pick the encoder function based on the EncoderType
506-
create_encoder_func = LearningModel.create_visual_observation_encoder
507-
if vis_encode_type == EncoderType.RESNET:
508-
create_encoder_func = LearningModel.create_resnet_visual_observation_encoder
509-
elif vis_encode_type == EncoderType.NATURE_CNN:
510-
create_encoder_func = (
511-
LearningModel.create_nature_cnn_visual_observation_encoder
512-
)
513-
514541
final_hiddens = []
515542
for i in range(num_streams):
543+
# Pick the encoder function based on the EncoderType
544+
create_encoder_func = LearningModel.get_encoder_for_type(vis_encode_type)
545+
516546
visual_encoders = []
517547
hidden_state, hidden_visual = None, None
518548
_scope_add = stream_scopes[i] if stream_scopes else ""
@@ -523,8 +553,8 @@ def create_observation_streams(
523553
h_size,
524554
activation_fn,
525555
num_layers,
526-
scope=f"{_scope_add}main_graph_{i}_encoder{j}",
527-
reuse=False,
556+
f"{_scope_add}main_graph_{i}_encoder{j}", # scope
557+
False, # reuse
528558
)
529559
visual_encoders.append(encoded_visual)
530560
hidden_visual = tf.concat(visual_encoders, axis=1)

ml-agents/mlagents/trainers/tests/test_ppo.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from mlagents.trainers.ppo.models import PPOModel
1010
from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
1111
from mlagents.trainers.ppo.policy import PPOPolicy
12-
from mlagents.trainers.brain import BrainParameters
12+
from mlagents.trainers.models import EncoderType, LearningModel
13+
from mlagents.trainers.trainer import UnityTrainerException
14+
from mlagents.trainers.brain import BrainParameters, CameraResolution
1315
from mlagents_envs.environment import UnityEnvironment
1416
from mlagents_envs.mock_communicator import MockCommunicator
1517
from mlagents.trainers.tests import mock_brain as mb
@@ -499,5 +501,41 @@ def test_normalization(dummy_config):
499501
assert (variance[0] - 1) / steps == pytest.approx(0.152, abs=0.01)
500502

501503

504+
def test_min_visual_size():
    # Make sure each EncoderType has an entry in MIN_RESOLUTION_FOR_ENCODER
    assert set(LearningModel.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)

    for encoder_type in EncoderType:
        # A fresh graph per encoder keeps op names from colliding across iterations.
        with tf.Graph().as_default():
            # At exactly the minimum resolution, validation and encoder
            # construction must both succeed without raising.
            good_size = LearningModel.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
            good_res = CameraResolution(
                width=good_size, height=good_size, num_channels=3
            )
            LearningModel._check_resolution_for_encoder(good_res, encoder_type)
            vis_input = LearningModel.create_visual_input(
                good_res, "test_min_visual_size"
            )
            enc_func = LearningModel.get_encoder_for_type(encoder_type)
            enc_func(vis_input, 32, LearningModel.swish, 1, "test", False)

        # Anything under the min size should raise an exception. If not, decrease the min size!
        with pytest.raises(Exception):
            with tf.Graph().as_default():
                bad_size = LearningModel.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
                bad_res = CameraResolution(
                    width=bad_size, height=bad_size, num_channels=3
                )

                # The inner raises-check consumes the friendly validation
                # error; building the encoder afterwards must still fail,
                # which satisfies the outer pytest.raises(Exception).
                with pytest.raises(UnityTrainerException):
                    # Make sure we'd hit a friendly error during model setup time.
                    LearningModel._check_resolution_for_encoder(bad_res, encoder_type)

                vis_input = LearningModel.create_visual_input(
                    bad_res, "test_min_visual_size"
                )
                enc_func = LearningModel.get_encoder_for_type(encoder_type)
                enc_func(vis_input, 32, LearningModel.swish, 1, "test", False)
539+
502540
if __name__ == "__main__":
503541
pytest.main()

0 commit comments

Comments
 (0)