-
Notifications
You must be signed in to change notification settings - Fork 4.3k
check min size for visual encoders #3112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -427,6 +438,17 @@ def create_resnet_visual_observation_encoder( | |||
) | |||
return hidden_flat | |||
|
|||
@staticmethod | |||
def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction: | |||
ENCODER_FUNCTION_BY_TYPE = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think I can make this dict a class member :/
enc_func(vis_input, 32, LearningModel.swish, 1, "test", False) | ||
|
||
# Anything under the min size should raise an exception. If not, decrease the min size! | ||
with pytest.raises(Exception): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we except the specific exception (and check that it's the negative dimension issue) just in case the model has a different error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need the exact type; the test is sufficient to make sure that size N works and N-1 doesn't. Besides, it's possible that different encoders would fail for different reasons.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
The visual encoders have a minimum usable size due to the strides. Currently, if you go below this size, we get an ugly looking exception in keras.
This PR adds a friendlier exception when you go under the min size, and adds tests to make sure the min is actually the min.