From fdea17d8b449cbee9719ab4022a24e2d9918c25f Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 11 Oct 2016 08:42:35 -0800 Subject: [PATCH] Store SparseTensors in a Map inside a container for Queue round-trip. This is much more efficient than serializing the underlying Tensors to strings and dserializing them on the other side. Instead we pass through the keys to the SparseTensors inside the Map. Methods are kept private for use by queueing wrappers. Includes benchmarks that show wall-time is almost 50% of the wall-time of using the sparse serialization/deserialization wrappers: I1003 17:24:34.355306 18675 benchmark.py:77] Benchmark [BenchmarkSparseTensorsMapVsSerialization.benchmark_very_large_2d_float_st_tensor_maps] iters: 2000, wall_time: 0.00260997, cpu_time: -1,throughput: -1 I1003 17:24:42.735983 18675 benchmark.py:77] Benchmark [BenchmarkSparseTensorsMapVsSerialization.benchmark_vey_large_2d_float_st_serialization] iters: 2000, wall_time: 0.00415492, cpu_time: -1,throughput: -1 *** Update: After updates to sparse_tensor.h's concat code (pushed in a sister PR), there's a speedup in both benchmarks: I1004 09:39:30.630354 24400 benchmark.py:77] Benchmark [BenchmarkSparseTensorsMapVsSerialization.benchmark_very_large_2d_float_st_tensor_maps] iters: 2000, wall_time: 0.0022105 I1004 09:39:38.125391 24400 benchmark.py:77] Benchmark [BenchmarkSparseTensorsMapVsSerialization.benchmark_very_large_2d_float_st_serialization] iters: 2000, wall_time: 0.00372696 *** Update 2: After properly placed std::moves in the sparse_tensors_map code, that benchmark is now faster: Benchmark [BenchmarkSparseTensorsMapVsSerialization.benchmark_very_large_2d_float_st_tensor_maps] iters: 2000, wall_time: 0.00187492 Total speedup is now: 0.00415492 / 0.00187492 = 2.2x Change: 135805924 --- tensorflow/core/kernels/BUILD | 1 + .../core/kernels/sparse_tensors_map_ops.cc | 494 ++++++++++++++++++ tensorflow/core/ops/sparse_ops.cc | 185 +++++++ tensorflow/core/ops/sparse_ops_test.cc | 35 ++ tensorflow/python/kernel_tests/BUILD | 11 +- .../sparse_tensors_map_ops_test.py | 234 +++++++++ tensorflow/python/ops/hidden_ops.txt | 3 + tensorflow/python/ops/sparse_ops.py | 147 ++++++ 8 files changed, 1109 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/kernels/sparse_tensors_map_ops.cc create mode 100644 tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 6f9f618217d121..ba07cedee07155 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1867,6 +1867,7 @@ tf_kernel_libraries( "sparse_to_dense_op", "sparse_xent_op", "serialize_sparse_op", + "sparse_tensors_map_ops", ], deps = [ ":bounds_check", diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc new file mode 100644 index 00000000000000..5673ab4ee5bbfc --- /dev/null +++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc @@ -0,0 +1,494 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +using sparse::SparseTensor; + +class SparseTensorsMap : public ResourceBase { + public: + SparseTensorsMap(const string& name) : name_(name), counter_(0) {} + + string DebugString() override { return "A SparseTensorsMap"; } + + typedef struct { + PersistentTensor indices; + PersistentTensor values; + TensorShape shape; + } PersistentSparseTensor; + + Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, + int64* handle) { + PersistentTensor persistent_ix; + Tensor* ix; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix)); + *ix = sp.indices(); + + PersistentTensor persistent_values; + Tensor* values; + TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.indices().dtype(), + sp.indices().shape(), + &persistent_values, &values)); + *values = sp.values(); + { + mutex_lock l(mu_); + int64 unique_st_handle = counter_++; // increment is guarded on purpose + sp_tensors_[unique_st_handle] = + PersistentSparseTensor{persistent_ix, persistent_values, sp.shape()}; + *handle = unique_st_handle; + } + return Status::OK(); + } + + Status RetrieveAndClearSparseTensors( + OpKernelContext* ctx, const TTypes::ConstVec& handles, + std::vector* sparse_tensors) { + sparse_tensors->clear(); + sparse_tensors->reserve(handles.size()); + { + mutex_lock l(mu_); + for (size_t i = 0; i < handles.size(); ++i) { + const int64 handle = handles(i); + auto sp_iter = sp_tensors_.find(handle); + if (sp_iter == sp_tensors_.end()) { + return errors::InvalidArgument("Unable to find SparseTensor: ", + handle, " in map: ", name_); + } + const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx); + const Tensor* values = sp_iter->second.values.AccessTensor(ctx); + const TensorShape& shape = sp_iter->second.shape; + sparse_tensors->emplace_back(*ix, *values, shape); + + sp_tensors_.erase(sp_iter); + } + } + + return Status::OK(); + } + + protected: + ~SparseTensorsMap() override {} + + private: + string name_; + + mutex mu_; + int64 counter_ GUARDED_BY(mu_); + std::unordered_map sp_tensors_ GUARDED_BY(mu_); +}; + +class SparseTensorAccessingOp : public OpKernel { + public: + typedef std::function CreatorCallback; + + SparseTensorAccessingOp(OpKernelConstruction* context) + : OpKernel(context), sparse_tensors_map_(nullptr) {} + + protected: + ~SparseTensorAccessingOp() { + if (sparse_tensors_map_) sparse_tensors_map_->Unref(); + } + + Status GetMap(OpKernelContext* ctx, bool is_writing, + SparseTensorsMap** sparse_tensors_map) { + mutex_lock l(mu_); + + if (sparse_tensors_map_) { + *sparse_tensors_map = sparse_tensors_map_; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(), + is_writing /* use_node_name_as_default */)); + + CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) { + SparseTensorsMap *map = new SparseTensorsMap(cinfo_.name()); + *c = map; + return Status::OK(); + }; + + TF_RETURN_IF_ERROR( + cinfo_.resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &sparse_tensors_map_, + sparse_tensors_map_creator)); + + *sparse_tensors_map = sparse_tensors_map_; + return Status::OK(); + } + + private: + ContainerInfo cinfo_; + + mutex mu_; + SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_); +}; + +class AddSparseToTensorsMapOp : public SparseTensorAccessingOp { + public: + explicit AddSparseToTensorsMapOp(OpKernelConstruction* context) + : SparseTensorAccessingOp(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor* input_indices; + const Tensor* input_values; + const Tensor* input_shape; + SparseTensorsMap* map; + + OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); + OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); + OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); + OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), + errors::InvalidArgument( + "Input indices should be a matrix but received shape ", + input_indices->shape().DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), + errors::InvalidArgument( + "Input values should be a vector but received shape ", + input_values->shape().DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), + errors::InvalidArgument( + "Input shape should be a vector but received shape ", + input_shape->shape().DebugString())); + + TensorShape input_shape_object; + OP_REQUIRES_OK(context, + TensorShapeUtils::MakeShape(input_shape->vec().data(), + input_shape->NumElements(), + &input_shape_object)); + SparseTensor st(*input_indices, *input_values, input_shape_object); + int64 handle; + OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle)); + + Tensor sparse_handle(DT_INT64, TensorShape({})); + auto sparse_handle_t = sparse_handle.scalar(); + + sparse_handle_t() = handle; + + context->set_output(0, sparse_handle); + } +}; + +REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU), + AddSparseToTensorsMapOp); + +template +class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp { + public: + explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context) + : SparseTensorAccessingOp(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor* input_indices; + const Tensor* input_values; + const Tensor* input_shape; + SparseTensorsMap* map; + + OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); + OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); + OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); + OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), + errors::InvalidArgument( + "Input indices should be a matrix but received shape ", + input_indices->shape().DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), + errors::InvalidArgument( + "Input values should be a vector but received shape ", + input_values->shape().DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), + errors::InvalidArgument( + "Input shape should be a vector but received shape ", + input_shape->shape().DebugString())); + + int rank = input_shape->NumElements(); + + OP_REQUIRES( + context, rank > 1, + errors::InvalidArgument( + "Rank of input SparseTensor should be > 1, but saw rank: ", rank)); + + TensorShape tensor_input_shape(input_shape->vec()); + gtl::InlinedVector std_order(rank); + std::iota(std_order.begin(), std_order.end(), 0); + SparseTensor input_st(*input_indices, *input_values, tensor_input_shape, + std_order); + + auto input_shape_t = input_shape->vec(); + const int64 N = input_shape_t(0); + + Tensor sparse_handles(DT_INT64, TensorShape({N})); + auto sparse_handles_t = sparse_handles.vec(); + + OP_REQUIRES_OK(context, input_st.IndicesValid()); + + // We can generate the output shape proto string now, for all + // minibatch entries. + TensorShape output_shape; + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + input_shape_t.data() + 1, + input_shape->NumElements() - 1, &output_shape)); + + // Get groups by minibatch dimension + std::unordered_set visited; + sparse::GroupIterable minibatch = input_st.group({0}); + for (const auto& subset : minibatch) { + const int64 b = subset.group()[0]; + visited.insert(b); + OP_REQUIRES( + context, b > -1 && b < N, + errors::InvalidArgument( + "Received unexpected column 0 value in input SparseTensor: ", b, + " < 0 or >= N (= ", N, ")")); + + const auto indices = subset.indices(); + const auto values = subset.values(); + const int64 num_entries = values.size(); + + Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1}); + Tensor output_values = Tensor(DataTypeToEnum::value, {num_entries}); + + auto output_indices_t = output_indices.matrix(); + auto output_values_t = output_values.vec(); + + for (int i = 0; i < num_entries; ++i) { + for (int d = 1; d < rank; ++d) { + output_indices_t(i, d - 1) = indices(i, d); + } + output_values_t(i) = values(i); + } + + SparseTensor st_i(output_indices, output_values, output_shape); + int64 handle; + OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle)); + sparse_handles_t(b) = handle; + } + + // Fill in any gaps; we must provide an empty ST for batch entries + // the grouper didn't find. + if (visited.size() < N) { + Tensor empty_indices(DT_INT64, {0, rank - 1}); + Tensor empty_values(DataTypeToEnum::value, {0}); + SparseTensor empty_st(empty_indices, empty_values, output_shape); + + for (int64 b = 0; b < N; ++b) { + // We skipped this batch entry. + if (visited.find(b) == visited.end()) { + int64 handle; + OP_REQUIRES_OK(context, + map->AddSparseTensor(context, empty_st, &handle)); + sparse_handles_t(b) = handle; + } + } + } + + context->set_output(0, sparse_handles); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + AddManySparseToTensorsMapOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +template +class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { + public: + explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context) + : SparseTensorAccessingOp(context) {} + + void Compute(OpKernelContext* context) override { + SparseTensorsMap* map; + OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map)); + + const Tensor& sparse_handles = context->input(0); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()), + errors::InvalidArgument( + "sparse_handles should be a vector but received shape ", + sparse_handles.shape().DebugString())); + + int64 N = sparse_handles.shape().dim_size(0); + + OP_REQUIRES( + context, N > 0, + errors::InvalidArgument("Must have at least 1 serialized SparseTensor, " + "but input matrix has 0 rows")); + + std::vector indices_to_concat; + std::vector values_to_concat; + std::vector shapes_to_concat; + + const auto& sparse_handles_t = sparse_handles.vec(); + + std::vector sparse_tensors; + + OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors( + context, sparse_handles_t, &sparse_tensors)); + + for (int64 i = 0; i < N; ++i) { + const SparseTensor& st = sparse_tensors[i]; + const Tensor& output_indices = st.indices(); + const Tensor& output_values = st.values(); + const TensorShape& output_shape = st.shape(); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), + errors::InvalidArgument( + "Expected sparse_handles[", i, + "] to represent an index matrix but received shape ", + output_indices.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()), + errors::InvalidArgument( + "Expected sparse_handles[", i, + "] to represent a values vector but received shape ", + output_values.shape().DebugString())); + OP_REQUIRES( + context, DataTypeToEnum::value == output_values.dtype(), + errors::InvalidArgument( + "Requested SparseTensor of type ", + DataTypeString(DataTypeToEnum::value), " but SparseTensor[", i, + "].values.dtype() == ", DataTypeString(output_values.dtype()))); + + int64 num_entries = output_indices.dim_size(0); + OP_REQUIRES(context, num_entries == output_values.dim_size(0), + errors::InvalidArgument( + "Expected row counts of SparseTensor[", i, + "].indices and SparseTensor[", i, + "].values to match but they do not: ", num_entries, + " vs. ", output_values.dim_size(0))); + int rank = output_indices.dim_size(1); + OP_REQUIRES( + context, rank == output_shape.dims(), + errors::InvalidArgument("Expected column counts of SparseTensor[", i, + "].indices to match size of SparseTensor[", i, + "].shape " + "but they do not: ", + rank, " vs. ", output_shape.dims())); + + // Now we expand each SparseTensors' indices and shape by + // prefixing a dimension + Tensor expanded_indices( + DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)})); + Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); + const auto& output_indices_t = output_indices.matrix(); + auto expanded_indices_t = expanded_indices.matrix(); + auto expanded_shape_t = expanded_shape.vec(); + expanded_indices_t.chip<1>(0).setZero(); + Eigen::DSizes indices_start(0, 1); + Eigen::DSizes indices_sizes(num_entries, rank); + expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; + expanded_shape_t(0) = 1; + // TODO: copy shape from TensorShape to &expanded_shape_t(1) + // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); + for (int i = 0; i < rank; ++i) { + expanded_shape_t(i + 1) = output_shape.dim_size(i); + } + TensorShape expanded_tensor_shape(expanded_shape_t); + + indices_to_concat.push_back(std::move(expanded_indices)); + values_to_concat.push_back(output_values); + shapes_to_concat.push_back(std::move(expanded_tensor_shape)); + } + + int rank = -1; + for (int i = 0; i < N; ++i) { + if (rank < 0) rank = shapes_to_concat[i].dims(); + OP_REQUIRES(context, rank == shapes_to_concat[i].dims(), + errors::InvalidArgument( + "Inconsistent rank across SparseTensors: rank prior to " + "SparseTensor[", + i, "] was: ", rank, " but rank of SparseTensor[", i, + "] is: ", shapes_to_concat[i].dims())); + } + + // SparseTensor::Concat requires consistent shape for all but the + // primary order dimension (dimension 0 in this case). So we get + // the maximum value across all the input SparseTensors for each + // dimension and use that. + TensorShape preconcat_shape(shapes_to_concat[0]); + for (int i = 0; i < N; ++i) { + for (int d = 0; d < rank; ++d) { + preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d), + shapes_to_concat[i].dim_size(d))); + } + } + + // Dimension 0 is the primary dimension. + gtl::InlinedVector std_order(rank); + std::iota(std_order.begin(), std_order.end(), 0); + + std::vector tensors_to_concat; + for (int i = 0; i < N; ++i) { + tensors_to_concat.emplace_back(std::move(indices_to_concat[i]), + std::move(values_to_concat[i]), + preconcat_shape, std_order); + } + + SparseTensor output(SparseTensor::Concat(tensors_to_concat)); + + Tensor final_output_shape(DT_INT64, TensorShape({output.dims()})); + + std::copy_n(output.shape().dim_sizes().data(), output.dims(), + final_output_shape.vec().data()); + + context->set_output(0, output.indices()); + context->set_output(1, output.values()); + context->set_output(2, final_output_shape); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + TakeManySparseFromTensorsMapOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index a1d6b648e7abde..6372aea95a0701 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -860,4 +860,189 @@ output_indices: 2-D. The indices of the output SparseTensor. output_values: 1-D. The values of the output SparseTensor. )doc"); +REGISTER_OP("AddSparseToTensorsMap") + .Input("sparse_indices: int64") + .Input("sparse_values: T") + .Input("sparse_shape: int64") + .Output("sparse_handle: int64") + .Attr("T: type") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Add a `SparseTensor` to a `SparseTensorsMap` return its handle. + +A `SparseTensor` is represented by three tensors: `sparse_indices`, +`sparse_values`, and `sparse_shape`. + +This operator takes the given `SparseTensor` and adds it to a container +object (a `SparseTensorsMap`). A unique key within this container is generated +in the form of an `int64`, and this is the value that is returned. + +The `SparseTensor` can then be read out as part of a minibatch by passing +the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure +the correct `SparseTensorsMap` is accessed, ensure that the same +`container` and `shared_name` are passed to that Op. If no `shared_name` +is provided here, instead use the *name* of the Operation created by calling +`AddSparseToTensorsMap` as the `shared_name` passed to +`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. + +sparse_indices: 2-D. The `indices` of the `SparseTensor`. +sparse_values: 1-D. The `values` of the `SparseTensor`. +sparse_shape: 1-D. The `shape` of the `SparseTensor`. +sparse_handle: 0-D. The handle of the `SparseTensor` now stored in the + `SparseTensorsMap`. +container: The container name for the `SparseTensorsMap` created by this op. +shared_name: The shared name for the `SparseTensorsMap` created by this op. + If blank, the new Operation's unique name is used. +)doc"); + +REGISTER_OP("AddManySparseToTensorsMap") + .Input("sparse_indices: int64") + .Input("sparse_values: T") + .Input("sparse_shape: int64") + .Output("sparse_handles: int64") + .Attr("T: type") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. + +A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, +`sparse_values`, and `sparse_shape`, where + +```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` + +An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` +having a first `sparse_indices` column taking values between `[0, N)`, where +the minibatch size `N == sparse_shape[0]`. + +The input `SparseTensor` must have rank `R` greater than 1, and the first +dimension is treated as the minibatch dimension. Elements of the `SparseTensor` +must be sorted in increasing order of this first dimension. The stored +`SparseTensor` objects pointed to by each row of the output `sparse_handles` +will have rank `R-1`. + +The `SparseTensor` values can then be read out as part of a minibatch by passing +the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure +the correct `SparseTensorsMap` is accessed, ensure that the same +`container` and `shared_name` are passed to that Op. If no `shared_name` +is provided here, instead use the *name* of the Operation created by calling +`AddManySparseToTensorsMap` as the `shared_name` passed to +`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. + +sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. + `sparse_indices[:, 0]` must be ordered values in `[0, N)`. +sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. + The minibatch size `N == sparse_shape[0]`. +sparse_handles: 1-D. The handles of the `SparseTensor` now stored in the + `SparseTensorsMap`. Shape: `[N]`. +container: The container name for the `SparseTensorsMap` created by this op. +shared_name: The shared name for the `SparseTensorsMap` created by this op. + If blank, the new Operation's unique name is used. +)doc"); + +REGISTER_OP("TakeManySparseFromTensorsMap") + .Input("sparse_handles: int64") + .Output("sparse_indices: int64") + .Output("sparse_values: dtype") + .Output("sparse_shape: int64") + .Attr("dtype: type") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + // serialized sparse is [?,1] matrix. + ShapeHandle sparse_handles; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles)); + + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. + +The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where +`N` is the minibatch size and the rows correspond to the output handles of +`AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the +original `SparseTensor` objects that went into the given input ops must all +match. When the final `SparseTensor` is created, it has rank one +higher than the ranks of the incoming `SparseTensor` objects +(they have been concatenated along a new row dimension on the left). + +The output `SparseTensor` object's shape values for all dimensions but the +first are the max across the input `SparseTensor` objects' shape values +for the corresponding dimensions. Its first shape value is `N`, the minibatch +size. + +The input `SparseTensor` objects' indices are assumed ordered in +standard lexicographic order. If this is not the case, after this +step run `SparseReorder` to restore index ordering. + +For example, if the handles represent an input, which is a `[2, 3]` matrix +representing two original `SparseTensor` objects: + +``` + index = [ 0] + [10] + [20] + values = [1, 2, 3] + shape = [50] +``` + +and + +``` + index = [ 2] + [10] + values = [4, 5] + shape = [30] +``` + +then the final `SparseTensor` will be: + +``` + index = [0 0] + [0 10] + [0 20] + [1 2] + [1 10] + values = [1, 2, 3, 4, 5] + shape = [2 50] +``` + +sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. + Shape: `[N]`. +sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +dtype: The `dtype` of the `SparseTensor` objects stored in the + `SparseTensorsMap`. +container: The container name for the `SparseTensorsMap` read by this op. +shared_name: The shared name for the `SparseTensorsMap` read by this op. + It should not be blank; rather the `shared_name` or unique Operation name + of the Op that created the original `SparseTensorsMap` should be used. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/ops/sparse_ops_test.cc b/tensorflow/core/ops/sparse_ops_test.cc index 3b738c26044edb..eaaaea2e989bcc 100644 --- a/tensorflow/core/ops/sparse_ops_test.cc +++ b/tensorflow/core/ops/sparse_ops_test.cc @@ -292,4 +292,39 @@ TEST(SparseOpsTest, SparseDenseCwise_ShapeFn) { } } +TEST(SparseOpsTest, AddSparseToTensorsMap_ShapeFn) { + ShapeInferenceTestOp op("AddSparseToTensorsMap"); + + // Rank checks. + INFER_ERROR("must be rank 2", op, "[1];?;?"); + INFER_ERROR("must be rank 1", op, "?;[];?"); + INFER_ERROR("must be rank 1", op, "?;?;[]"); + + // output is always scalar + INFER_OK(op, "?;?;?", "[]"); +} + +TEST(SparseOpsTest, AddManySparseToTensorsMap_ShapeFn) { + ShapeInferenceTestOp op("AddManySparseToTensorsMap"); + + // Rank checks. + INFER_ERROR("must be rank 2", op, "[1];?;?"); + INFER_ERROR("must be rank 1", op, "?;[];?"); + INFER_ERROR("must be rank 1", op, "?;?;[]"); + + // output is always matrix of [?]. + INFER_OK(op, "?;?;?", "[?]"); +} + +TEST(SparseOpsTest, TakeManySparseFromTensorsMap_ShapeFn) { + ShapeInferenceTestOp op("TakeManySparseFromTensorsMap"); + + // Rank checks. + INFER_ERROR("must be rank 1", op, "[?,1]"); + + // output is always [?,?];[?];[?]. + INFER_OK(op, "?", "[?,?];[?];[?]"); + INFER_OK(op, "[?]", "[?,?];[?];[?]"); +} + } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 94168433ae3a5c..f911d8c1c7f2a5 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -921,7 +921,7 @@ cuda_py_test( ], ) -cuda_py_test( +tf_py_test( name = "sparse_serialization_ops_test", size = "small", srcs = ["sparse_serialization_ops_test.py"], @@ -930,6 +930,15 @@ cuda_py_test( ], ) +tf_py_test( + name = "sparse_tensors_map_ops_test", + size = "small", + srcs = ["sparse_tensors_map_ops_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + ], +) + cuda_py_test( name = "sparse_tensor_dense_matmul_grad_test", size = "small", diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py new file mode 100644 index 00000000000000..56e9701494d1cc --- /dev/null +++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py @@ -0,0 +1,234 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for SparseTensorsMap.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.python.ops import sparse_ops + +# pylint: disable=protected-access +add_sparse_to_tensors_map = sparse_ops._add_sparse_to_tensors_map +add_many_sparse_to_tensors_map = sparse_ops._add_many_sparse_to_tensors_map +take_many_sparse_from_tensors_map = ( + sparse_ops._take_many_sparse_from_tensors_map) +# pylint: enable=protected-access + + +class SparseTensorsMapTest(tf.test.TestCase): + + def _SparseTensorPlaceholder(self, dtype=None): + if dtype is None: dtype = tf.int32 + return tf.SparseTensor( + tf.placeholder(tf.int64), + tf.placeholder(dtype), + tf.placeholder(tf.int64)) + + def _SparseTensorValue_5x6(self, permutation): + ind = np.array([ + [0, 0], + [1, 0], [1, 3], [1, 4], + [3, 2], [3, 3]]).astype(np.int64) + val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32) + + ind = ind[permutation] + val = val[permutation] + + shape = np.array([5, 6]).astype(np.int64) + return tf.SparseTensorValue(ind, val, shape) + + def _SparseTensorValue_3x4(self, permutation): + ind = np.array([ + [0, 0], + [1, 0], [1, 2], [1, 3], + [2, 2], [2, 3]]).astype(np.int64) + val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32) + + ind = ind[permutation] + val = val[permutation] + + shape = np.array([3, 4]).astype(np.int64) + return tf.SparseTensorValue(ind, val, shape) + + def _SparseTensorValue_1x1x1(self): + ind = np.array([[0, 0, 0]]).astype(np.int64) + val = np.array([0]).astype(np.int32) + shape = np.array([3, 4, 5]).astype(np.int64) + return tf.SparseTensorValue(ind, val, shape) + + def testAddTakeMany(self): + with self.test_session(graph=tf.Graph(), use_gpu=False) as sess: + sp_input0 = self._SparseTensorValue_5x6(np.arange(6)) + sp_input1 = self._SparseTensorValue_3x4(np.arange(6)) + handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a") + handle1 = add_sparse_to_tensors_map(sp_input1, shared_name="a") + self.assertEqual(handle0.get_shape(), ()) + handles_concat = tf.pack([handle0, handle1]) + + sp_out = take_many_sparse_from_tensors_map( + sparse_map_op=handle0.op, sparse_handles=handles_concat) + + combined_indices, combined_values, combined_shape = sess.run(sp_out) + + self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 + self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0]) + self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1 + self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0]) + self.assertAllEqual(combined_values[:6], sp_input0[1]) + self.assertAllEqual(combined_values[6:], sp_input1[1]) + self.assertAllEqual(combined_shape, [2, 5, 6]) + + def testFeedAddTakeMany(self): + with self.test_session(use_gpu=False) as sess: + sp_input = self._SparseTensorPlaceholder() + input0_val = self._SparseTensorValue_5x6(np.arange(6)) + input1_val = self._SparseTensorValue_3x4(np.arange(6)) + handle = add_sparse_to_tensors_map(sp_input) + + handle0_value = sess.run( + handle, feed_dict={sp_input: input0_val}) + handle1_value = sess.run( + handle, feed_dict={sp_input: input1_val}) + + sparse_handles = tf.convert_to_tensor( + [handle0_value, handle1_value], dtype=tf.int64) + + sp_roundtrip = take_many_sparse_from_tensors_map( + sparse_map_op=handle.op, sparse_handles=sparse_handles) + + combined_indices, combined_values, combined_shape = sess.run( + sp_roundtrip) + + self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 + self.assertAllEqual(combined_indices[:6, 1:], input0_val[0]) + self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1 + self.assertAllEqual(combined_indices[6:, 1:], input1_val[0]) + self.assertAllEqual(combined_values[:6], input0_val[1]) + self.assertAllEqual(combined_values[6:], input1_val[1]) + self.assertAllEqual(combined_shape, [2, 5, 6]) + + def testAddManyTakeManyRoundTrip(self): + with self.test_session(use_gpu=False) as sess: + # N == 4 because shape_value == [4, 5] + indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64) + values_value = np.array([b"a", b"b", b"c"]) + shape_value = np.array([4, 5], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder(dtype=tf.string) + handles = add_many_sparse_to_tensors_map(sparse_tensor) + roundtrip = take_many_sparse_from_tensors_map( + sparse_map_op=handles.op, sparse_handles=handles) + handles_value, roundtrip_value = sess.run( + [handles, roundtrip], + feed_dict={sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.shape: shape_value}) + self.assertEqual(handles_value.shape, (4,)) + self.assertAllEqual(roundtrip_value.indices, indices_value) + self.assertAllEqual(roundtrip_value.values, values_value) + self.assertAllEqual(roundtrip_value.shape, shape_value) + + def testDeserializeFailsInconsistentRank(self): + with self.test_session(use_gpu=False) as sess: + sp_input = self._SparseTensorPlaceholder() + input0_val = self._SparseTensorValue_5x6(np.arange(6)) + input1_val = self._SparseTensorValue_1x1x1() + handle = add_sparse_to_tensors_map(sp_input) + + handle0_value = sess.run( + handle, feed_dict={sp_input: input0_val}) + handle1_value = sess.run( + handle, feed_dict={sp_input: input1_val}) + + handle_concat = tf.convert_to_tensor( + [handle0_value, handle1_value], dtype=tf.int64) + + sp_roundtrip = take_many_sparse_from_tensors_map( + sparse_map_op=handle.op, sparse_handles=handle_concat) + + with self.assertRaisesOpError( + r"Inconsistent rank across SparseTensors: rank prior to " + r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"): + sess.run(sp_roundtrip) + + def testTakeManyFailsWrongInputOp(self): + with self.test_session(use_gpu=False) as sess: + input_val = self._SparseTensorValue_5x6(np.arange(6)) + handle = add_sparse_to_tensors_map(input_val) + handle_value = sess.run(handle) + bad_handle = handle_value + 10 + sp_roundtrip = take_many_sparse_from_tensors_map( + sparse_map_op=handle.op, + sparse_handles=[handle_value, bad_handle]) + + with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"): + sess.run(sp_roundtrip) + + +class BenchmarkSparseTensorsMapVsSerialization(tf.test.Benchmark): + + def benchmarkVeryLarge2DFloatSparseTensor(self): + np.random.seed(127) + num_elements = 10000 + batch_size = 64 + indices_batch = np.random.randint( + batch_size, size=num_elements, dtype=np.int64) + indices_value = np.arange(num_elements, dtype=np.int64) + indices = np.asarray( + sorted(zip(indices_batch, indices_value)), dtype=np.int64) + values = ["feature_value_for_embedding_lookup"] * num_elements + shape = np.asarray([batch_size, num_elements], dtype=np.int64) + with tf.Session() as sess: + with tf.device("/cpu:0"): + indices = tf.Variable(indices) + values = tf.Variable(values) + shape = tf.Variable(shape) + st = tf.SparseTensor(indices, values, shape) + + st_handles = add_many_sparse_to_tensors_map(st) + st_roundtrip = take_many_sparse_from_tensors_map( + sparse_map_op=st_handles.op, sparse_handles=st_handles) + st_roundtrip_op = st_roundtrip.values.op + + st_serialized = tf.serialize_many_sparse(st) + st_deserialized = tf.deserialize_many_sparse( + st_serialized, dtype=values.dtype) + st_deserialized_op = st_deserialized.values.op + + tf.initialize_all_variables().run() + + st_roundtrip_values = sess.run(st_roundtrip) + st_deserialized_values = sess.run(st_deserialized) + np.testing.assert_equal( + st_roundtrip_values.values, st_deserialized_values.values) + np.testing.assert_equal( + st_roundtrip_values.indices, st_deserialized_values.indices) + np.testing.assert_equal( + st_roundtrip_values.shape, st_deserialized_values.shape) + + self.run_op_benchmark( + sess, st_roundtrip_op, min_iters=2000, + name="benchmark_very_large_2d_float_st_tensor_maps") + self.run_op_benchmark( + sess, st_deserialized_op, min_iters=2000, + name="benchmark_very_large_2d_float_st_serialization") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 63a4f854a0e5b5..015dd3fc6027dd 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -227,6 +227,9 @@ TemporaryVariable DestroyTemporaryVariable # sparse_ops +AddSparseToTensorsMap +AddManySparseToTensorsMap +TakeManySparseFromTensorsMap DeserializeManySparse SerializeManySparse SerializeSparse diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 9669f4e21c6fdd..68c4f8ca236df8 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1548,3 +1548,150 @@ def sparse_transpose(sp_input, perm=None, name=None): transposed_dense_shape) transposed_st = sparse_reorder(transposed_st) return transposed_st + + +def _add_sparse_to_tensors_map(sp_input, container=None, + shared_name=None, name=None): + """Add a `SparseTensor` to a `SparseTensorsMap` and return its handle. + + Args: + sp_input: The input `SparseTensor`. + container: The container for the underlying `SparseTensorsMap` (optional). + shared_name: The shared name for the underlying `SparseTensorsMap` + (optional, defaults to the name of the newly created op). + name: A name prefix for the returned tensors (optional). + + Returns: + A string 1-vector (1D `Tensor`), with the single element representing the + a unique handle to a `SparseTensor` stored by the `SparseTensorMap` + underlying this op. + + Raises: + TypeError: If `sp_input` is not a `SparseTensor`. + """ + sp_input = _convert_to_sparse_tensor(sp_input) + + return gen_sparse_ops._add_sparse_to_tensors_map( + sp_input.indices, sp_input.values, sp_input.shape, + container=container, shared_name=shared_name, name=name) + + +def _add_many_sparse_to_tensors_map(sp_input, container=None, + shared_name=None, name=None): + """Add a minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. + + The `SparseTensor` must have rank `R` greater than 1, and the first dimension + is treated as the minibatch dimension. Elements of the `SparseTensor` + must be sorted in increasing order of this first dimension. The serialized + `SparseTensor` objects going into each row of the output `Tensor` will have + rank `R-1`. + + The minibatch size `N` is extracted from `sparse_shape[0]`. + + Args: + sp_input: The input rank `R` `SparseTensor`. + container: The container for the underlying `SparseTensorsMap` (optional). + shared_name: The shared name for the underlying `SparseTensorsMap` + (optional, defaults to the name of the newly created op). + name: A name prefix for the returned tensors (optional). + + Returns: + A string matrix (2-D `Tensor`) with `N` rows and `1` column. + Each row represents a unique handle to a `SparseTensor` stored by + the `SparseTensorMap` underlying this op. + + Raises: + TypeError: If `sp_input` is not a `SparseTensor`. + """ + sp_input = _convert_to_sparse_tensor(sp_input) + + return gen_sparse_ops._add_many_sparse_to_tensors_map( + sp_input.indices, sp_input.values, sp_input.shape, + container=container, shared_name=shared_name, name=name) + + +def _take_many_sparse_from_tensors_map( + sparse_map_op, sparse_handles, rank=None, name=None): + """Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. + + The input `sparse_handles` must be a string matrix of shape `[N, 1]` where + `N` is the minibatch size and the rows correspond to packed outputs of + `add_sparse_to_tensors_map`. The ranks of the original `SparseTensor` objects + must all match. When the final `SparseTensor` is created, it has rank one + higher than the ranks of the incoming `SparseTensor` objects (they have been + concatenated along a new row dimension). + + The output `SparseTensor` object's shape values for all dimensions but the + first are the max across the input `SparseTensor` objects' shape values + for the corresponding dimensions. Its first shape value is `N`, the minibatch + size. + + The input `SparseTensor` objects' indices are assumed ordered in + standard lexicographic order. If this is not the case, after this + step run `sparse_reorder` to restore index ordering. + + For example, if the serialized input is a `[2, 3]` matrix representing two + original `SparseTensor` objects: + + index = [ 0] + [10] + [20] + values = [1, 2, 3] + shape = [50] + + and + + index = [ 2] + [10] + values = [4, 5] + shape = [30] + + then the final deserialized `SparseTensor` will be: + + index = [0 0] + [0 10] + [0 20] + [1 2] + [1 10] + values = [1, 2, 3, 4, 5] + shape = [2 50] + + Args: + sparse_map_op: The `Operation` that created the original handles. + Usually this is, e.g., `add_sparse_to_tensors_map(...).op`. + sparse_handles: 2-D `Tensor` of type `string` of shape `[N, 1]`. + The serialized and packed `SparseTensor` objects. + rank: (optional) Python int, the rank of the `SparseTensor` objects. + name: A name prefix for the returned tensors (optional) + + Returns: + A `SparseTensor` representing the deserialized `SparseTensor`s, + concatenated along the `SparseTensor`s' first dimension. + + All of the serialized `SparseTensor`s must have had the same rank and type. + """ + if not isinstance(sparse_map_op, ops.Operation): + raise TypeError("sparse_map_op be an Operation") + if sparse_map_op.type not in ("AddSparseToTensorsMap", + "AddManySparseToTensorsMap"): + raise TypeError("sparse_map_op must be one of AddSparseToTensorsMap or " + "AddSparseToTensorsMap") + with ops.colocate_with(sparse_map_op): + shared_name = sparse_map_op.get_attr("shared_name") or sparse_map_op.name + output_indices, output_values, output_shape = ( + gen_sparse_ops._take_many_sparse_from_tensors_map( + sparse_handles, dtype=sparse_map_op.get_attr("T"), + container=sparse_map_op.get_attr("container"), + shared_name=shared_name, name=name)) + + # Feed rank data back in, if available + output_indices.set_shape([None, rank]) + output_shape.set_shape([rank]) + + return ops.SparseTensor(output_indices, output_values, output_shape) + + +ops.RegisterShape("AddSparseToTensorsMap")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("AddManySparseToTensorsMap")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TakeManySparseFromTensorsMap")( + common_shapes.call_cpp_shape_fn)