11
11
logger = logging .getLogger ("mlagents.trainers" )
12
12
13
13
ActivationFunction = Callable [[tf .Tensor ], tf .Tensor ]
14
+ EncoderFunction = Callable [
15
+ [tf .Tensor , int , ActivationFunction , int , str , bool ], tf .Tensor
16
+ ]
14
17
15
18
EPSILON = 1e-7
16
19
@@ -26,9 +29,17 @@ class LearningRateSchedule(Enum):
26
29
LINEAR = "linear"
27
30
28
31
29
- class LearningModel ( object ) :
32
+ class LearningModel :
30
33
_version_number_ = 2
31
34
35
+ # Minimum supported side for each encoder type. If refactoring an encoder, please
36
+ # adjust these also.
37
+ MIN_RESOLUTION_FOR_ENCODER = {
38
+ EncoderType .SIMPLE : 20 ,
39
+ EncoderType .NATURE_CNN : 36 ,
40
+ EncoderType .RESNET : 15 ,
41
+ }
42
+
32
43
def __init__ (
33
44
self , m_size , normalize , use_recurrent , brain , seed , stream_names = None
34
45
):
@@ -427,6 +438,17 @@ def create_resnet_visual_observation_encoder(
427
438
)
428
439
return hidden_flat
429
440
441
@staticmethod
def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
    """
    Look up the visual-observation encoder builder for an encoder type.

    :param encoder_type: The EncoderType whose builder function is wanted.
    :return: The LearningModel encoder-creation function registered for
        encoder_type. Any type that is not registered falls back to
        LearningModel.create_visual_observation_encoder.
    """
    # The simple encoder doubles as the fallback for unregistered types.
    fallback_encoder = LearningModel.create_visual_observation_encoder
    encoder_builders = {
        EncoderType.SIMPLE: fallback_encoder,
        EncoderType.NATURE_CNN: LearningModel.create_nature_cnn_visual_observation_encoder,
        EncoderType.RESNET: LearningModel.create_resnet_visual_observation_encoder,
    }
    try:
        return encoder_builders[encoder_type]
    except KeyError:
        return fallback_encoder
451
+
430
452
@staticmethod
431
453
def create_discrete_action_masking_layer (all_logits , action_masks , action_size ):
432
454
"""
@@ -474,6 +496,17 @@ def create_discrete_action_masking_layer(all_logits, action_masks, action_size):
474
496
),
475
497
)
476
498
499
@staticmethod
def _check_resolution_for_encoder(
    camera_res: CameraResolution, vis_encoder_type: EncoderType
) -> None:
    """
    Verify that a visual observation is large enough for the chosen encoder.

    :param camera_res: Resolution of the visual observation. Both its height
        and its width are compared against the encoder's minimum side length.
    :param vis_encoder_type: Encoder type, used as the key into
        LearningModel.MIN_RESOLUTION_FOR_ENCODER.
    :raises UnityTrainerException: If the height or the width is smaller than
        the minimum registered for vis_encoder_type.
    :raises KeyError: If vis_encoder_type has no entry in
        LearningModel.MIN_RESOLUTION_FOR_ENCODER.
    """
    min_res = LearningModel.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
    if camera_res.height < min_res or camera_res.width < min_res:
        # Adjacent f-string literals are concatenated verbatim, so the first
        # fragment must end with a space; without it the message read
        # "...is too small forthe provided EncoderType...".
        raise UnityTrainerException(
            f"Visual observation resolution ({camera_res.width} x{camera_res.height} ) is too small for "
            f"the provided EncoderType ({vis_encoder_type.value} ). The min dimension is {min_res} "
        )
509
+
477
510
def create_observation_streams (
478
511
self ,
479
512
num_streams : int ,
@@ -496,23 +529,20 @@ def create_observation_streams(
496
529
497
530
self .visual_in = []
498
531
for i in range (brain .number_visual_observations ):
532
+ LearningModel ._check_resolution_for_encoder (
533
+ brain .camera_resolutions [i ], vis_encode_type
534
+ )
499
535
visual_input = self .create_visual_input (
500
536
brain .camera_resolutions [i ], name = "visual_observation_" + str (i )
501
537
)
502
538
self .visual_in .append (visual_input )
503
539
vector_observation_input = self .create_vector_input ()
504
540
505
- # Pick the encoder function based on the EncoderType
506
- create_encoder_func = LearningModel .create_visual_observation_encoder
507
- if vis_encode_type == EncoderType .RESNET :
508
- create_encoder_func = LearningModel .create_resnet_visual_observation_encoder
509
- elif vis_encode_type == EncoderType .NATURE_CNN :
510
- create_encoder_func = (
511
- LearningModel .create_nature_cnn_visual_observation_encoder
512
- )
513
-
514
541
final_hiddens = []
515
542
for i in range (num_streams ):
543
+ # Pick the encoder function based on the EncoderType
544
+ create_encoder_func = LearningModel .get_encoder_for_type (vis_encode_type )
545
+
516
546
visual_encoders = []
517
547
hidden_state , hidden_visual = None , None
518
548
_scope_add = stream_scopes [i ] if stream_scopes else ""
@@ -523,8 +553,8 @@ def create_observation_streams(
523
553
h_size ,
524
554
activation_fn ,
525
555
num_layers ,
526
- scope = f"{ _scope_add } main_graph_{ i } _encoder{ j } " ,
527
- reuse = False ,
556
+ f"{ _scope_add } main_graph_{ i } _encoder{ j } " , # scope
557
+ False , # reuse
528
558
)
529
559
visual_encoders .append (encoded_visual )
530
560
hidden_visual = tf .concat (visual_encoders , axis = 1 )
0 commit comments