Skip to content

Commit

Permalink
[CMake] Fix support for custom kernels in tf.contrib.metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrry committed Dec 6, 2016
1 parent 81b25af commit b5e6d66
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions tensorflow/contrib/cmake/tf_core_ops.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contr
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(metrics_set "${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc")


########################################################
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/contrib/cmake/tf_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_metrics_set_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/metrics/python/ops/gen_set_ops.py)

add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES})
add_dependencies(tf_python_ops tf_python_op_gen_main)
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/cmake/tf_tests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/training/*_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/metrics/*_test.py"
)

# exclude the onces we don't want
Expand Down
7 changes: 1 addition & 6 deletions tensorflow/contrib/metrics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,7 @@ tf_gen_op_libs(

tf_gen_op_wrapper_py(
name = "set_ops",
hidden = [
"DenseToDenseSetOperation",
"DenseToSparseSetOperation",
"SparseToSparseSetOperation",
"SetSize",
],
out = "python/ops/gen_set_ops.py",
deps = [":set_ops_op_lib"],
)

Expand Down
9 changes: 5 additions & 4 deletions tensorflow/contrib/metrics/python/ops/set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from tensorflow.contrib.framework.python.framework import tensor_util

from tensorflow.contrib.metrics.python.ops import gen_set_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
Expand Down Expand Up @@ -56,7 +57,7 @@ def set_size(a, validate_indices=True):
if a.values.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("Invalid dtype %s." % a.values.dtype)
# pylint: disable=protected-access
return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices)

ops.NotDifferentiable("SetSize")

Expand Down Expand Up @@ -100,17 +101,17 @@ def _set_operation(a, b, set_operation, validate_indices=True):
# pylint: disable=protected-access
if isinstance(a, sparse_tensor.SparseTensor):
if isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
a.indices, a.values, a.shape, b.indices, b.values, b.shape,
set_operation, validate_indices)
else:
raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
"Please flip the order of your inputs.")
elif isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = _set_ops.dense_to_sparse_set_operation(
indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
a, b.indices, b.values, b.shape, set_operation, validate_indices)
else:
indices, values, shape = _set_ops.dense_to_dense_set_operation(
indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
a, b, set_operation, validate_indices)
# pylint: enable=protected-access
return sparse_tensor.SparseTensor(indices, values, shape)
Expand Down

0 comments on commit b5e6d66

Please sign in to comment.