Skip to content

Commit

Permalink
Make hub.ImageModuleInfo, hub.attach_image_module_info() visible
Browse files Browse the repository at this point in the history
for use by module publishers.

PiperOrigin-RevId: 212239584
  • Loading branch information
TensorFlow Hub Authors authored and arnoegw committed Sep 14, 2018
1 parent e1eabdd commit d3ca254
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions tensorflow_hub/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":all_protos_py_pb2",
":native_module",
],
)

Expand Down
4 changes: 4 additions & 0 deletions tensorflow_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
from tensorflow_hub.estimator import register_module_for_export
from tensorflow_hub.feature_column import image_embedding_column
from tensorflow_hub.feature_column import text_embedding_column
from tensorflow_hub.image_util import attach_image_module_info
from tensorflow_hub.image_util import get_expected_image_size
from tensorflow_hub.image_util import get_num_image_channels
from tensorflow_hub.image_util import ImageModuleInfo
from tensorflow_hub.module import load_module_spec
from tensorflow_hub.module import Module
from tensorflow_hub.module_spec import ModuleSpec
Expand All @@ -53,8 +55,10 @@
"register_module_for_export",
"image_embedding_column",
"text_embedding_column",
"attach_image_module_info",
"get_expected_image_size",
"get_num_image_channels",
"ImageModuleInfo",
"Module",
"ModuleSpec",
"add_signature",
Expand Down
18 changes: 14 additions & 4 deletions tensorflow_hub/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,27 @@
from __future__ import print_function

from tensorflow_hub import image_module_info_pb2
from tensorflow_hub import native_module


# Image modules can provide further information for the utilities in this file
# by attaching an ImageModuleInfo message under this key.
# hub.Modules for images can provide further information for the utilities
# in this file by attaching an ImageModuleInfo message under this key.
IMAGE_MODULE_INFO_KEY = "image_module_info"


def get_image_module_info(module_or_spec):
# The externally visible name of the message is hub.ImageModuleInfo
ImageModuleInfo = image_module_info_pb2.ImageModuleInfo # pylint: disable=invalid-name


def attach_image_module_info(image_module_info):
"""Attaches an ImageModuleInfo message from within a module_fn."""
native_module.attach_message(IMAGE_MODULE_INFO_KEY, image_module_info)


def get_image_module_info(module_or_spec, required=False):
"""Returns the module's attached ImageModuleInfo message, or None."""
return module_or_spec.get_attached_message(
IMAGE_MODULE_INFO_KEY, image_module_info_pb2.ImageModuleInfo)
IMAGE_MODULE_INFO_KEY, ImageModuleInfo, required=required)


def get_expected_image_size(module_or_spec, signature=None, input_name=None):
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_hub/image_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_hub import image_module_info_pb2
from tensorflow_hub import image_util
from tensorflow_hub import module
from tensorflow_hub import native_module
Expand All @@ -39,11 +38,10 @@ def image_module_fn_with_info():
sum_all = tf.reduce_sum(images, [1, 2, 3])
native_module.add_signature(inputs=dict(images=images),
outputs=dict(default=sum_all))
image_module_info = image_module_info_pb2.ImageModuleInfo()
image_module_info = image_util.ImageModuleInfo()
size = image_module_info.default_image_size
size.height, size.width = 2, 4
native_module.attach_message(image_util.IMAGE_MODULE_INFO_KEY,
image_module_info)
image_util.attach_image_module_info(image_module_info)


class ImageModuleTest(tf.test.TestCase):
Expand Down

0 comments on commit d3ca254

Please sign in to comment.