From b5e6d667d28f8dd1ea21ed428fa560b0d4720afa Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 6 Dec 2016 08:46:04 -0800 Subject: [PATCH] [CMake] Fix support for custom kernels in `tf.contrib.metrics`. Fixes #6115. --- tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 2 ++ tensorflow/contrib/cmake/tf_tests.cmake | 1 + tensorflow/contrib/metrics/BUILD | 7 +------ tensorflow/contrib/metrics/python/ops/set_ops.py | 9 +++++---- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 5523023cb7f76e..7f8faff4e58c5c 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -46,6 +46,7 @@ GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contr GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(metrics_set "${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index fcbe238652d94f..cb8845680c5cfe 100644 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -494,6 +494,8 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py) GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_metrics_set_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/metrics/python/ops/gen_set_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 8608d3ff8fb371..cf86419087f97b 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -120,6 +120,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py" "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/metrics/*_test.py" ) # exclude the onces we don't want diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index bee8b9567a3eb1..96c0433f49659a 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -35,12 +35,7 @@ tf_gen_op_libs( tf_gen_op_wrapper_py( name = "set_ops", - hidden = [ - "DenseToDenseSetOperation", - "DenseToSparseSetOperation", - "SparseToSparseSetOperation", - "SetSize", - ], + out = "python/ops/gen_set_ops.py", deps = [":set_ops_op_lib"], ) diff --git a/tensorflow/contrib/metrics/python/ops/set_ops.py b/tensorflow/contrib/metrics/python/ops/set_ops.py index dd737a14c29bf0..31437707418951 100644 --- a/tensorflow/contrib/metrics/python/ops/set_ops.py +++ b/tensorflow/contrib/metrics/python/ops/set_ops.py @@ -19,6 +19,7 @@ from tensorflow.contrib.framework.python.framework import tensor_util +from tensorflow.contrib.metrics.python.ops import gen_set_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,7 +57,7 @@ def set_size(a, validate_indices=True): if a.values.dtype.base_dtype not in _VALID_DTYPES: raise TypeError("Invalid dtype %s." % a.values.dtype) # pylint: disable=protected-access - return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices) + return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices) ops.NotDifferentiable("SetSize") @@ -100,17 +101,17 @@ def _set_operation(a, b, set_operation, validate_indices=True): # pylint: disable=protected-access if isinstance(a, sparse_tensor.SparseTensor): if isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = _set_ops.sparse_to_sparse_set_operation( + indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation( a.indices, a.values, a.shape, b.indices, b.values, b.shape, set_operation, validate_indices) else: raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. " "Please flip the order of your inputs.") elif isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = _set_ops.dense_to_sparse_set_operation( + indices, values, shape = gen_set_ops.dense_to_sparse_set_operation( a, b.indices, b.values, b.shape, set_operation, validate_indices) else: - indices, values, shape = _set_ops.dense_to_dense_set_operation( + indices, values, shape = gen_set_ops.dense_to_dense_set_operation( a, b, set_operation, validate_indices) # pylint: enable=protected-access return sparse_tensor.SparseTensor(indices, values, shape)