diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8a154f947f39db..275568bbdd8545 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -215,6 +215,7 @@ tf_kernel_library(
         "split_lib.h",
     ],
     deps = [
+        ":cuda_device_array",
         "//tensorflow/core:framework",
         "//third_party/eigen3",
     ],
@@ -246,6 +247,18 @@ cc_header_only_library(
     deps = [":bounds_check"],
 )
 
+cc_library(
+    name = "cuda_device_array",
+    hdrs = [
+        "cuda_device_array.h",
+        "cuda_device_array_gpu.h",
+    ],
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        "//tensorflow/core:lib",
+    ],
+)
+
 cc_library(
     name = "eigen_helpers",
     hdrs = [
@@ -318,6 +331,7 @@ tf_kernel_libraries(
         ":batchtospace_op",
         ":bounds_check",
         ":concat_lib",
+        ":cuda_device_array",
         ":depth_space_ops",
         ":fill_functor",
         ":ops_util",
diff --git a/tensorflow/core/kernels/cuda_device_array.h b/tensorflow/core/kernels/cuda_device_array.h
new file mode 100644
index 00000000000000..26e77fb7ce59e0
--- /dev/null
+++ b/tensorflow/core/kernels/cuda_device_array.h
@@ -0,0 +1,120 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+
+namespace tensorflow {
+
+// Creates an array of values on the host, to be sent to a kernel via
+// CudaDeviceArrayStruct.
+//
+// Usage:
+//   int size = ...;
+//   CudaDeviceArrayOnHost<ValueType> ptrs(context, size);
+//   OP_REQUIRES_OK(context, ptrs.Init());
+//   for (int i = 0; i < size; ++i) {
+//     ptrs.Set(i, ...);
+//   }
+//   OP_REQUIRES_OK(context, ptrs.Finalize());
+//   launchKernel(..., ptrs.data(), ...);
+//
+// ValueType must be memcopyable.
+template <typename ValueType, int MaxInlineValues = 8>
+class CudaDeviceArrayOnHost {
+ public:
+  CudaDeviceArrayOnHost(OpKernelContext* context, int32 size)
+      : context_(context),
+        total_bytes_(static_cast<int64>(size) * sizeof(ValueType)) {
+    data_.size = size;
+  }
+
+  Status Init() {
+    if (inlined()) {
+      values_ = data_.inline_values;
+      return Status::OK();
+    }
+
+    // Out-of-line: allocate host data that will be memcopied to the device.
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    attr.set_gpu_compatible(true);
+    TF_RETURN_IF_ERROR(
+        context_->allocate_temp(DT_INT8, TensorShape{total_bytes_},
+                                &out_of_line_values_on_host_, attr));
+    values_ = reinterpret_cast<ValueType*>(
+        out_of_line_values_on_host_.flat<int8>().data());
+    return Status::OK();
+  }
+
+  void Set(int index, ValueType val) {
+    DCHECK(values_);  // ensure Init was called.
+    DCHECK_LT(index, data_.size);
+    *(values_ + index) = val;
+  }
+
+  Status Finalize() {
+    if (inlined()) {
+      return Status::OK();
+    }
+
+    // Out-of-line: copy the values to the device.
+    auto stream = context_->op_device_context()->stream();
+    TensorReference tensor_ref(out_of_line_values_on_host_);
+    TF_RETURN_IF_ERROR(context_->allocate_temp(
+        DT_INT8, TensorShape{total_bytes_}, &out_of_line_values_on_gpu_));
+    perftools::gputools::DeviceMemoryBase output_values_base{
+        out_of_line_values_on_gpu_.flat<int8>().data(),
+        static_cast<uint64>(total_bytes_)};
+    stream->ThenMemcpy(&output_values_base,
+                       out_of_line_values_on_host_.flat<int8>().data(),
+                       total_bytes_);
+    // Keep the host tensor alive until the copy has completed.
+    context_->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+        stream, [tensor_ref]() { tensor_ref.Unref(); });
+    data_.out_of_line_values = reinterpret_cast<ValueType*>(
+        out_of_line_values_on_gpu_.flat<int8>().data());
+    return Status::OK();
+  }
+
+  const CudaDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
+    // Ensure Finalize was called.
+    DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized());
+    return data_;
+  }
+
+ private:
+  bool inlined() const { return data_.size <= MaxInlineValues; }
+
+  OpKernelContext* const context_;
+  const int64 total_bytes_;  // total size of the array in bytes.
+  ValueType* values_ = nullptr;
+  CudaDeviceArrayStruct<ValueType, MaxInlineValues> data_;
+
+  Tensor out_of_line_values_on_host_;
+  Tensor out_of_line_values_on_gpu_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CudaDeviceArrayOnHost);
+};
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
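The usage comment above gives the host-side protocol in shorthand. For review purposes, here is a minimal sketch of how an op might drive the helper end to end; `ComputeOutputPointers` and `LaunchMyKernel` are hypothetical names for illustration, not part of this change:

    // Sketch only: assumes the includes and namespace of cuda_device_array.h,
    // plus a GPU op producing `num_outputs` float outputs of the same shape.
    void ComputeOutputPointers(OpKernelContext* context,
                               const TensorShape& shape, int num_outputs) {
      CudaDeviceArrayOnHost<float*> ptrs(context, num_outputs);
      OP_REQUIRES_OK(context, ptrs.Init());
      for (int i = 0; i < num_outputs; ++i) {
        Tensor* out = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(i, shape, &out));
        ptrs.Set(i, out->flat<float>().data());
      }
      // Inline case (num_outputs <= 8): Finalize is a no-op.
      // Out-of-line case: a single H2D memcpy is enqueued on the op's stream.
      OP_REQUIRES_OK(context, ptrs.Finalize());
      // Hypothetical launcher; the struct is passed to the kernel by value.
      LaunchMyKernel(context->eigen_device<Eigen::GpuDevice>(), ptrs.data());
    }
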
diff --git a/tensorflow/core/kernels/cuda_device_array_gpu.h b/tensorflow/core/kernels/cuda_device_array_gpu.h
new file mode 100644
index 00000000000000..f6ebcccb0d5088
--- /dev/null
+++ b/tensorflow/core/kernels/cuda_device_array_gpu.h
@@ -0,0 +1,50 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains structs and functions to be included in device code.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+
+#if GOOGLE_CUDA
+
+namespace tensorflow {
+
+static constexpr int kMaxInlineCudaPointers = 8;
+
+// To decode on the device side, use GetCudaDeviceArrayOnDevice.
+// To encode on the host side, use CudaDeviceArrayOnHost.
+template <typename ValueType, int MaxInlineValues = 8>
+struct CudaDeviceArrayStruct {
+  int32 size;
+  // inline_values is used if size <= MaxInlineValues.
+  ValueType inline_values[MaxInlineValues];
+  ValueType* out_of_line_values = nullptr;  // used if size > MaxInlineValues.
+};
+
+template <typename ValueType, int MaxInlineValues>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
+    CudaDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
+  if (data->size <= MaxInlineValues) {
+    return data->inline_values;
+  } else {
+    return data->out_of_line_values;
+  }
+}
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
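On the device side, a kernel receives the struct by value and decodes it with GetCudaDeviceArrayOnDevice, which returns the inline storage when size <= MaxInlineValues and the out-of-line device buffer otherwise. A toy kernel, illustrative only (`FillFirstElems` is not part of this change), might look like:

    // Writes `val` into element 0 of each of the `ptr_data.size` buffers.
    template <typename T>
    __global__ void FillFirstElems(CudaDeviceArrayStruct<T*> ptr_data, T val) {
      // Inline values arrive in the kernel's parameter buffer, so the common
      // case needs no extra global-memory load to fetch the pointer table.
      T** ptrs = GetCudaDeviceArrayOnDevice(&ptr_data);
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < ptr_data.size) {
        ptrs[i][0] = val;
      }
    }
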
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 05eed86e42638a..5e82352d639a3a 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
@@ -46,9 +47,12 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 namespace {
 
 template <typename T>
-__global__ void SplitOpKernel(const T* input, int32 num_split,
-                              int32 prefix_dim_size, int32 split_dim_size,
-                              int32 suffix_dim_size, T** output_ptrs) {
+__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
+                              int32 split_dim_size, int32 suffix_dim_size,
+                              CudaDeviceArrayStruct<T*> output_ptr_data) {
+  const int32 num_split = output_ptr_data.size;
+  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
+
   eigen_assert(blockDim.y == 1);
   eigen_assert(blockDim.z == 1);
   eigen_assert(split_dim_size % num_split == 0);
@@ -79,16 +83,16 @@ __global__ void SplitOpKernel(const T* input, int32 num_split,
 
 template <typename T>
 struct SplitOpGPULaunch {
-  void Run(const Eigen::GpuDevice& d, const T* input, int32 num_split,
-           int32 prefix_dim_size, int32 split_dim_size, int32 suffix_dim_size,
-           T** output_ptrs) {
+  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
+           int32 split_dim_size, int32 suffix_dim_size,
+           const CudaDeviceArrayStruct<T*>& output_ptr_data) {
     CudaLaunchConfig config = GetCudaLaunchConfig(
         prefix_dim_size * split_dim_size * suffix_dim_size, d);
 
     SplitOpKernel<
         T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-        input, num_split, prefix_dim_size, split_dim_size, suffix_dim_size,
-        static_cast<T**>(output_ptrs));
+        input, prefix_dim_size, split_dim_size, suffix_dim_size,
+        output_ptr_data);
   }
 };
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 1d42b93c605b57..30346a914f8975 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/kernels/cuda_device_array.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -182,9 +183,9 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
 
 template <typename T>
 struct SplitOpGPULaunch {
-  void Run(const Eigen::GpuDevice& d, const T* input, int32 split_dim,
-           int32 prefix_dim_size, int32 split_dim_size, int32 suffix_dim_size,
-           T** output_ptrs_vec);
+  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
+           int32 split_dim_size, int32 suffix_dim_size,
+           const CudaDeviceArrayStruct<T*>& output_ptr_data);
 };
 
 // Partial specialization for GPU
@@ -219,44 +220,24 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     TensorShape output_shape(input_shape);
     output_shape.set_dim(split_dim, split_dim_output_size);
 
-    AllocatorAttributes attr;
-    attr.set_on_host(true);
-    attr.set_gpu_compatible(true);
-
-    Tensor output_ptrs_on_host;
-    Tensor output_ptrs_on_gpu;
-    int64 output_ptrs_total_bytes = static_cast<int64>(sizeof(T*) * num_split);
-    OP_REQUIRES_OK(context, context->allocate_temp(
-                                DT_INT8, TensorShape{output_ptrs_total_bytes},
-                                &output_ptrs_on_host, attr));
-    OP_REQUIRES_OK(context, context->allocate_temp(
-                                DT_INT8, TensorShape{output_ptrs_total_bytes},
-                                &output_ptrs_on_gpu));
-    T** output_ptrs_on_host_arr =
-        reinterpret_cast<T**>(output_ptrs_on_host.flat<int8>().data());
+    CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
+    OP_REQUIRES_OK(context, ptrs.Init());
+
     for (int i = 0; i < num_split; ++i) {
       Tensor* result = nullptr;
       OP_REQUIRES_OK(context,
                      context->allocate_output(i, output_shape, &result));
-      output_ptrs_on_host_arr[i] = result->flat<T>().data();
+      ptrs.Set(i, result->flat<T>().data());
     }
     if (prefix_dim_size * split_dim_output_size * suffix_dim_size == 0) {
       return;
     }
-    auto stream = context->op_device_context()->stream();
-    perftools::gputools::DeviceMemoryBase output_ptrs_base{
-        output_ptrs_on_gpu.flat<int8>().data(), static_cast<uint64>(num_split)};
-    TensorReference tensor_ref(output_ptrs_on_host);
-    stream->ThenMemcpy(&output_ptrs_base,
-                       output_ptrs_on_host.flat<int8>().data(),
-                       output_ptrs_total_bytes);
-    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
-        stream, [tensor_ref]() { tensor_ref.Unref(); });
-    SplitOpGPULaunch<T>().Run(
-        context->eigen_device<GPUDevice>(), input.flat<T>().data(), num_split,
-        prefix_dim_size, split_dim_size, suffix_dim_size,
-        reinterpret_cast<T**>(output_ptrs_on_gpu.flat<int8>().data()));
-    OP_REQUIRES(context, stream->ok(),
+    OP_REQUIRES_OK(context, ptrs.Finalize());
+
+    SplitOpGPULaunch<T>().Run(context->eigen_device<GPUDevice>(),
+                              input.flat<T>().data(), prefix_dim_size,
+                              split_dim_size, suffix_dim_size, ptrs.data());
+    OP_REQUIRES(context, context->op_device_context()->stream()->ok(),
                 errors::Internal("Launch of gpu kernel for SplitOp failed"));
   }
 };
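One implicit constraint worth keeping in mind: because CudaDeviceArrayStruct travels by value in the kernel's parameter buffer, its size counts against CUDA's 4 KB limit on __global__ function arguments. With the default of 8 inline values the struct is only about 80 bytes, so there is ample headroom; if a caller ever raised MaxInlineValues substantially, a compile-time guard along these lines (an illustrative assumption, not in this patch) would make the limit explicit:

    static_assert(sizeof(tensorflow::CudaDeviceArrayStruct<float*, 256>) <= 4096,
                  "CudaDeviceArrayStruct must fit within CUDA's 4KB kernel "
                  "parameter space");
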
#include "tensorflow/core/lib/gtl/array_slice.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/kernels/cuda_device_array.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -182,9 +183,9 @@ class SplitOpCPU : public SplitOpBase { template struct SplitOpGPULaunch { - void Run(const Eigen::GpuDevice& d, const T* input, int32 split_dim, - int32 prefix_dim_size, int32 split_dim_size, int32 suffix_dim_size, - T** output_ptrs_vec); + void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, + int32 split_dim_size, int32 suffix_dim_size, + const CudaDeviceArrayStruct& output_ptr_data); }; // Partial specialization for GPU @@ -219,44 +220,24 @@ class SplitOpGPU : public SplitOpBase { TensorShape output_shape(input_shape); output_shape.set_dim(split_dim, split_dim_output_size); - AllocatorAttributes attr; - attr.set_on_host(true); - attr.set_gpu_compatible(true); - - Tensor output_ptrs_on_host; - Tensor output_ptrs_on_gpu; - int64 output_ptrs_total_bytes = static_cast(sizeof(T*) * num_split); - OP_REQUIRES_OK(context, context->allocate_temp( - DT_INT8, TensorShape{output_ptrs_total_bytes}, - &output_ptrs_on_host, attr)); - OP_REQUIRES_OK(context, context->allocate_temp( - DT_INT8, TensorShape{output_ptrs_total_bytes}, - &output_ptrs_on_gpu)); - T** output_ptrs_on_host_arr = - reinterpret_cast(output_ptrs_on_host.flat().data()); + CudaDeviceArrayOnHost ptrs(context, num_split); + OP_REQUIRES_OK(context, ptrs.Init()); + for (int i = 0; i < num_split; ++i) { Tensor* result = nullptr; OP_REQUIRES_OK(context, context->allocate_output(i, output_shape, &result)); - output_ptrs_on_host_arr[i] = result->flat().data(); + ptrs.Set(i, result->flat().data()); } if (prefix_dim_size * split_dim_output_size * suffix_dim_size == 0) { return; } - auto stream = context->op_device_context()->stream(); - perftools::gputools::DeviceMemoryBase output_ptrs_base{ - output_ptrs_on_gpu.flat().data(), static_cast(num_split)}; - TensorReference tensor_ref(output_ptrs_on_host); - stream->ThenMemcpy(&output_ptrs_base, - output_ptrs_on_host.flat().data(), - output_ptrs_total_bytes); - context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( - stream, [tensor_ref]() { tensor_ref.Unref(); }); - SplitOpGPULaunch().Run( - context->eigen_device(), input.flat().data(), num_split, - prefix_dim_size, split_dim_size, suffix_dim_size, - reinterpret_cast(output_ptrs_on_gpu.flat().data())); - OP_REQUIRES(context, stream->ok(), + OP_REQUIRES_OK(context, ptrs.Finalize()); + + SplitOpGPULaunch().Run(context->eigen_device(), + input.flat().data(), prefix_dim_size, + split_dim_size, suffix_dim_size, ptrs.data()); + OP_REQUIRES(context, context->op_device_context()->stream()->ok(), errors::Internal("Launch of gpu kernel for SplitOp failed")); } }; diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index d5c1b9883bde10..7834599f2306f1 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -85,11 +85,14 @@ def testSplitDim0(self): self._compare(np.random.rand(6, 7, 18).astype("f"), 0, 3, use_gpu) self._compare(np.random.rand(6, 7, 9).astype("f"), 0, 3, use_gpu) - def _RunAndVerify(self, use_gpu): + def _RunAndVerify(self, use_gpu, large_num_splits=False): # Random dims of rank 5 shape = np.random.randint(0, 5, size=5) split_dim = np.random.randint(0, 5) - num_split = np.random.randint(2, 8) + if 
+      num_split = np.random.randint(9, 15)
+    else:
+      num_split = np.random.randint(2, 8)
     shape[split_dim] = np.random.randint(2, 5) * num_split
     inp = np.random.rand(*shape).astype("f")
     with self.test_session(use_gpu=use_gpu) as sess:
@@ -106,6 +109,7 @@ def testRandom(self):
     for _ in range(5):
       self._RunAndVerify(use_gpu=False)
       self._RunAndVerify(use_gpu=True)
+      self._RunAndVerify(use_gpu=True, large_num_splits=True)
 
   def _testGradientsSimple(self, use_gpu):
     inp = np.random.rand(4, 4).astype("f")