Add checkpoint support for cuDNN RNN.
Yao Zhang authored and tensorflower-gardener committed Dec 7, 2016
1 parent 661dee2 commit ebcf0ee
Showing 6 changed files with 483 additions and 55 deletions.
3 changes: 2 additions & 1 deletion tensorflow/contrib/cudnn_rnn/BUILD
@@ -3,16 +3,17 @@
# APIs are meant to change over time.
package(
default_visibility = ["//visibility:private"],
features = ["-parse_headers"],
)

licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")

tf_custom_op_library(
name = "python/ops/_cudnn_rnn_ops.so",
1 change: 1 addition & 0 deletions tensorflow/contrib/cudnn_rnn/__init__.py
@@ -22,3 +22,4 @@
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanh
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import RNNParamsSaveable
196 changes: 196 additions & 0 deletions tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
@@ -43,6 +43,7 @@ limitations under the License.

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA

/*
@@ -78,6 +79,12 @@ using GPUDevice = Eigen::GpuDevice;
template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;

template <typename Device, typename T>
class CudnnRNNParamsToCanonical;

template <typename Device, typename T>
class CudnnRNNCanonicalToParams;

template <typename Device, typename T>
class CudnnRNNForwardOp;

@@ -96,6 +103,7 @@ using perftools::gputools::dnn::RnnInputMode;
using perftools::gputools::dnn::RnnDirectionMode;
using perftools::gputools::dnn::ToDataType;
using perftools::gputools::DeviceMemory;
using perftools::gputools::DeviceMemoryBase;
using perftools::gputools::ScratchAllocator;
using perftools::gputools::port::StatusOr;

@@ -184,6 +192,16 @@ DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
tensor->template flat<T>().size() * sizeof(T));
}

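// Returns a non-owning view of the byte range [offset, offset + size) inside
// device_memory; the parent allocation must outlive the returned view.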
DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
int64 offset, int64 size) {
const void* base_ptr = device_memory.opaque();
void* offset_ptr =
const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
CHECK(offset + size <= device_memory.size())
<< "The slice is not within the region of DeviceMemory.";
return DeviceMemoryBase(offset_ptr, size);
}

inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
return s.ok() ? Status::OK() : Status(static_cast<tensorflow::error::Code>(
static_cast<int>(s.code())),
@@ -357,6 +375,28 @@ Status ExtractForwardInput(OpKernelContext* context,

using perftools::gputools::dnn::RnnDescriptor;

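// Copies each canonical parameter tensor in params_input into its
// corresponding (offset, size) region of the opaque parameter buffer
// data_dst, enqueueing one memcpy per region on the given stream.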
template <typename T>
void RestoreParams(const OpInputList params_input,
const std::vector<RnnDescriptor::ParamsRegion>& params,
DeviceMemoryBase* data_dst,
perftools::gputools::Stream* stream) {
int num_params = params.size();
CHECK(params_input.size() == num_params)
<< "Number of params mismatch. Expected " << params_input.size()
<< ", got " << num_params;
for (int i = 0; i < params.size(); i++) {
int64 size_in_bytes = params[i].size;
int64 size = size_in_bytes / sizeof(T);
CHECK(size == params_input[i].NumElements())
<< "Params size mismatch. Expected " << size << ", got "
<< params_input[i].NumElements();
auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
DeviceMemoryBase data_dst_ptr =
SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
}
}

} // namespace

// A common base class for RNN kernels. It extracts common attributes and
@@ -458,6 +498,162 @@ REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")
.TypeConstraint<int32>("S"),
CudnnRNNParamsSizeOp<GPUDevice, float, int32>);

// Convert weight and bias params from a platform-specific layout to the
// canonical form.
template <typename T>
class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
typedef GPUDevice Device;
explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
}

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(3);
auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
auto* stream = context->op_device_context()->stream();

std::unique_ptr<RnnDescriptor> rnn_desc;
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
CHECK(params_size_in_bytes % sizeof(T) == 0)
<< "params_size_in_bytes must be multiple of element size";

const Tensor* num_units_t = nullptr;
OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
<< "num_units is not a scalar";
int num_units = num_units_t->scalar<int>()();

const Tensor* input_size_t = nullptr;
OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
<< "input_size is not a scalar";
int input_size = input_size_t->scalar<int>()();

const Tensor* num_layers_t = nullptr;
OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
<< "num_layers is not a scalar";
int num_layers = num_layers_t->scalar<int>()();
int num_params_per_layer = num_params_ / num_layers;

CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
<< "Number of params mismatch. Expected " << num_params_ << ", got "
<< rnn_desc->ParamsWeightRegions().size();
for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
int64 size = size_in_bytes / sizeof(T);
int width = (i < num_params_per_layer / 2) ? input_size : num_units;
int height = num_units;
CHECK(size == width * height) << "Params size mismatch. Expected "
<< width * height << ", got " << size;
// If data is aligned, use slice view to avoid expensive memcpy.
bool start_aligned =
rnn_desc->ParamsWeightRegions()[i].offset % EIGEN_MAX_ALIGN_BYTES ==
0;
bool size_aligned = size_in_bytes % EIGEN_MAX_ALIGN_BYTES == 0;
if (start_aligned && size_aligned) {
int start = rnn_desc->ParamsWeightRegions()[i].offset / sizeof(T);
int end = start + size_in_bytes / sizeof(T);
context->set_output(i, input.Slice(start, end));
} else {
Tensor* output = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(i, TensorShape({width, height}), &output));
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
input_ptr, rnn_desc->ParamsWeightRegions()[i].offset,
size_in_bytes);
auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
}
}

CHECK(num_params_ == rnn_desc->ParamsBiasRegions().size())
<< "Number of params mismatch. Expected " << num_params_ << ", got "
<< rnn_desc->ParamsBiasRegions().size();
for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
int64 size = size_in_bytes / sizeof(T);
CHECK(size == num_units) << "Params size mismatch. Expected " << num_units
<< ", got " << size;
// If data is aligned, use slice view to avoid expensive memcpy.
bool start_aligned =
rnn_desc->ParamsBiasRegions()[i].offset % EIGEN_MAX_ALIGN_BYTES == 0;
bool size_aligned = size_in_bytes % EIGEN_MAX_ALIGN_BYTES == 0;
if (start_aligned && size_aligned) {
int start = rnn_desc->ParamsBiasRegions()[i].offset / sizeof(T);
int end = start + size_in_bytes / sizeof(T);
context->set_output(num_params_ + i, input.Slice(start, end));
} else {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(num_params_ + i,
TensorShape({size}), &output));
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
}
}
}

private:
int num_params_;
};

REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical")
.Device(DEVICE_GPU)
.HostMemory("num_layers")
.HostMemory("num_units")
.HostMemory("input_size")
.TypeConstraint<float>("T"),
CudnnRNNParamsToCanonical<GPUDevice, float>);

// Convert weight and bias params from the canonical form to a
// platform-specific layout.
template <typename T>
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
typedef GPUDevice Device;
explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}

void Compute(OpKernelContext* context) override {
std::unique_ptr<RnnDescriptor> rnn_desc;
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
CHECK(params_size_in_bytes % sizeof(T) == 0)
<< "params_size_in_bytes must be multiple of element size";
Tensor* output = nullptr;
int params_size = params_size_in_bytes / sizeof(T);
OP_REQUIRES_OK(context,
context->allocate_output(0, {params_size}, &output));
auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
auto* stream = context->op_device_context()->stream();

OpInputList weights;
OP_REQUIRES_OK(context, context->input_list("weights", &weights));
RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
stream);

OpInputList biases;
OP_REQUIRES_OK(context, context->input_list("biases", &biases));
RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
stream);
}
};

REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams")
.Device(DEVICE_GPU)
.HostMemory("num_layers")
.HostMemory("num_units")
.HostMemory("input_size")
.TypeConstraint<float>("T"),
CudnnRNNCanonicalToParams<GPUDevice, float>);

// Run the forward operation of the RNN model.
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
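A note on the kernel above: CudnnRNNParamsToCanonical splits the opaque cuDNN parameter buffer into per-matrix and per-vector regions, handing out a zero-copy slice view when a region's offset and size are multiples of EIGEN_MAX_ALIGN_BYTES and falling back to an explicit copy otherwise. The following is a minimal host-side sketch of that decision, not the kernel itself: Region and kAlignBytes are hypothetical stand-ins for RnnDescriptor::ParamsRegion and EIGEN_MAX_ALIGN_BYTES, and an ordinary std::vector<float> stands in for DeviceMemory.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Hypothetical stand-in for RnnDescriptor::ParamsRegion: the byte offset and
// byte size of one weight matrix or bias vector inside the opaque buffer.
struct Region {
  int64_t offset;
  int64_t size;
};

constexpr int64_t kAlignBytes = 64;  // stand-in for EIGEN_MAX_ALIGN_BYTES

// Extracts one canonical parameter from the opaque buffer. An aligned region
// is returned as a zero-copy view into the buffer (analogous to
// input.Slice(start, end)); an unaligned region is copied into scratch
// (analogous to ThenMemcpy from a SliceDeviceMemory view).
const float* ExtractRegion(const std::vector<float>& opaque,
                           const Region& region, std::vector<float>* scratch) {
  assert(region.offset + region.size <=
         static_cast<int64_t>(opaque.size() * sizeof(float)));
  const char* base = reinterpret_cast<const char*>(opaque.data());
  const bool start_aligned = region.offset % kAlignBytes == 0;
  const bool size_aligned = region.size % kAlignBytes == 0;
  if (start_aligned && size_aligned) {
    return reinterpret_cast<const float*>(base + region.offset);
  }
  scratch->resize(region.size / sizeof(float));
  std::memcpy(scratch->data(), base + region.offset, region.size);
  return scratch->data();
}

int main() {
  std::vector<float> opaque(256);
  for (std::size_t i = 0; i < opaque.size(); ++i) {
    opaque[i] = static_cast<float>(i);
  }
  // 64 floats at byte offset 0: aligned, served as a view.
  // 3 floats at byte offset 260: unaligned, served as a copy.
  const std::vector<Region> regions = {{0, 64 * sizeof(float)},
                                       {260, 3 * sizeof(float)}};
  for (const Region& r : regions) {
    std::vector<float> scratch;
    const float* data = ExtractRegion(opaque, r, &scratch);
    std::cout << (scratch.empty() ? "view" : "copy")
              << ", first element: " << data[0] << "\n";
  }
  return 0;
}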
77 changes: 53 additions & 24 deletions tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
@@ -21,6 +21,12 @@ limitations under the License.
namespace tensorflow {
namespace {

constexpr auto kCudnnRNNCommonInputs = R"doc(
num_layers: Specifies the number of layers in the RNN model.
num_units: Specifies the size of the hidden state.
input_size: Specifies the size of the input state.
)doc";

constexpr auto kCudnnRNNCommonAttrs = R"doc(
rnn_mode: Indicates the type of the RNN model.
input_mode: Indicates whether there is a linear projection between the input and
@@ -47,12 +53,12 @@ constexpr auto kRNNInputModeAttrs =
constexpr auto kRNNDirectionAttrs =
"direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";

constexpr auto kCudnnRNNCanonicalParams = R"doc(
canonical_weights: the canonical form of weights that can be used for saving
constexpr auto kCudnnRNNParamsCanonical = R"doc(
weights: the canonical form of weights that can be used for saving
and restoration. They are more likely to be compatible across different
generations.
canonical_biases: the canonical form of biases that can be used for saving and
restoration. They are more likely to be compatible across different
biases: the canonical form of biases that can be used for saving
and restoration. They are more likely to be compatible across different
generations.
)doc";

@@ -80,11 +86,8 @@ REGISTER_OP("CudnnRNNParamsSize")
Return the params size that can be used by the Cudnn RNN model. Subsequent
weight allocation and initialization should use this size.
)doc",
kCudnnRNNCommonAttrs,
kCudnnRNNCommonInputs, kCudnnRNNCommonAttrs,
R"doc(
num_layers: Specifies the number of layers in the RNN model.
num_units: Specifies the size of the hidden state.
input_size: Specifies the size of the input state.
params_size: The size of the params buffer that should be allocated and
initialized for this RNN model. Note that this params buffer may not be
compatible across GPUs. Please use CudnnRNNParamsWeights and
@@ -213,46 +216,72 @@ params_backprop: The backprop to the params buffer in the forward pass. Has the
same shape as params.
)doc"));

// NOTE(zhengxq): this is not currently implemented yet. And may subject to
// change.
REGISTER_OP("CudnnRNNParamsToCanonical")
.Input("num_layers: int32")
.Input("num_units: int32")
.Input("input_size: int32")
.Input("params: T")
.Output("canonical_weights: T")
.Output("canonical_biases: T")
.Output("weights: num_params * T")
.Output("biases: num_params * T")
.Attr("T: {float}")
.Attr("N: int >= 1")
.Attr("num_params: int")
.Attr(kRNNModeAttrs)
.Attr(kRNNInputModeAttrs)
.Attr(kRNNDirectionAttrs)
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
int num_params;
c->GetAttr("num_params", &num_params);
// Set shape for weight matrices
for (int i = 0; i < num_params; i++) {
c->set_output(i,
c->Matrix(InferenceContext::kUnknownDim,
InferenceContext::kUnknownDim));
}
// Set shape for bias vectors
for (int i = 0; i < num_params; i++) {
c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
}
return Status::OK();
})
.Doc(strings::StrCat(R"doc(
Retrieves a set of weights from the opaque params buffer that can be saved and
restored in a way compatible with future runs.
)doc",
kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer,
kCudnnRNNCanonicalParams));
kCudnnRNNCommonInputs, kCudnnRNNParamsBuffer, R"doc(
num_params: number of parameter sets for all layers.
Each layer may contain multiple parameter sets, with each set consisting of
a weight matrix and a bias vector.
)doc",
kCudnnRNNParamsCanonical, kCudnnRNNCommonAttrs));

// NOTE(zhengxq): this is not currently implemented yet. And may subject to
// change.
REGISTER_OP("CudnnRNNParamsFromCanonical")
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
.Input("input_size: int32")
.Input("params: Ref(T)")
.Input("canonical_weights: T")
.Input("canonical_biases: T")
.Input("weights: num_params * T")
.Input("biases: num_params * T")
.Output("params: T")
.Attr("T: {float}")
.Attr("N: int >= 1")
.Attr("num_params: int")
.Attr(kRNNModeAttrs)
.Attr(kRNNInputModeAttrs)
.Attr(kRNNDirectionAttrs)
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
})
.Doc(strings::StrCat(R"doc(
Writes a set of weights into the opaque params buffer so they can be used in
upcoming training or inferences.
)doc",
kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer,
kCudnnRNNCanonicalParams));
kCudnnRNNCommonInputs, kCudnnRNNParamsCanonical,
kCudnnRNNParamsBuffer, R"doc(
num_params: number of parameter sets for all layers.
Each layer may contain multiple parameter sets, with each set consisting of
a weight matrix and a bias vector.
)doc",
kCudnnRNNCommonAttrs));

} // namespace tensorflow
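The registration changes above replace the single canonical_weights and canonical_biases tensors with num_params-element input/output lists, and CudnnRNNCanonicalToParams writes those lists back into the flat opaque buffer region by region, as RestoreParams does with ThenMemcpy on the device. Below is a minimal host-side sketch of that packing direction under the same assumptions as the earlier sketch (a hypothetical Region in place of RnnDescriptor::ParamsRegion, plain host memory in place of DeviceMemory):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Hypothetical stand-in for RnnDescriptor::ParamsRegion.
struct Region {
  int64_t offset;  // byte offset inside the opaque buffer
  int64_t size;    // byte size of the region
};

// Packs canonical parameter chunks back into a flat opaque buffer, mirroring
// RestoreParams: one memcpy per (offset, size) region.
void PackParams(const std::vector<std::vector<float>>& canonical,
                const std::vector<Region>& regions,
                std::vector<float>* opaque) {
  assert(canonical.size() == regions.size());
  char* base = reinterpret_cast<char*>(opaque->data());
  for (std::size_t i = 0; i < regions.size(); ++i) {
    assert(static_cast<int64_t>(canonical[i].size() * sizeof(float)) ==
           regions[i].size);
    assert(regions[i].offset + regions[i].size <=
           static_cast<int64_t>(opaque->size() * sizeof(float)));
    std::memcpy(base + regions[i].offset, canonical[i].data(),
                regions[i].size);
  }
}

int main() {
  // Two toy regions: a 2x3 "weight matrix" followed by a 2-element "bias".
  const std::vector<Region> regions = {{0, 6 * sizeof(float)},
                                       {6 * sizeof(float), 2 * sizeof(float)}};
  const std::vector<std::vector<float>> canonical = {{1, 2, 3, 4, 5, 6},
                                                     {7, 8}};
  std::vector<float> opaque(8, 0.0f);
  PackParams(canonical, regions, &opaque);
  for (float v : opaque) std::cout << v << " ";  // 1 2 3 4 5 6 7 8
  std::cout << "\n";
  return 0;
}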