Skip to content

Commit

Permalink
Fix injecting a _get_expected_image_size when given a Module.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 196956823
  • Loading branch information
TensorFlow Hub Authors authored and vbardiovskyg committed May 17, 2018
1 parent 648c41e commit f42ba67
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
4 changes: 3 additions & 1 deletion tensorflow_hub/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ py_library(
name = "image_util",
srcs = ["image_util.py"],
srcs_version = "PY2AND3",
deps = [],
deps = [
":module",
],
)

py_test(
Expand Down
16 changes: 11 additions & 5 deletions tensorflow_hub/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

from tensorflow_hub import module


def get_expected_image_size(module_or_spec, signature=None, input_name=None):
"""Returns expected [height, width] dimensions of an image input.
Expand All @@ -35,18 +37,22 @@ def get_expected_image_size(module_or_spec, signature=None, input_name=None):
Raises:
ValueError: If the size information is missing or malformed.
"""
# First try to use a module or spec specific implementation.
# First try to use a spec specific implementation.
#
# Note: this call into _get_expected_image_size is an implementation
# detail to make experimentation easier and suitable to change without
# notice.
if hasattr(module_or_spec, "_get_expected_image_size"):
# pylint: disable=protected-access
image_size = module_or_spec._get_expected_image_size(
# pylint: disable=protected-access
if isinstance(module_or_spec, module.Module):
spec = module_or_spec._spec
else:
spec = module_or_spec
if hasattr(spec, "_get_expected_image_size"):
image_size = spec._get_expected_image_size(
signature=signature, input_name=input_name)
if image_size is not None:
return image_size
# pylint: enable=protected-access
# pylint: enable=protected-access

# Fallback to inspect the input shape in the module signature.
if input_name is None:
Expand Down

0 comments on commit f42ba67

Please sign in to comment.