Commit f51e196

Move shared validation for StridedSliceOp to a separate header, in preparation
for calling it from the shape fn.

Change: 132732667
tensorflower-gardener committed Sep 10, 2016
1 parent af02c57 commit f51e196
Showing 4 changed files with 361 additions and 291 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/BUILD
@@ -287,6 +287,7 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
"util/tensor_slice_reader.h",
"util/tensor_slice_reader_cache.h",
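The new util/strided_slice_op.h header is not expanded in this view. Based on the call sites in the kernel below, its declaration presumably looks roughly like the following sketch; the name ValidateStridedSliceOp and its arguments appear in this diff, but the return type, parameter names, and ordering shown here are assumptions inferred from those calls rather than copied from the header.

// Hypothetical sketch of the declaration moved into
// tensorflow/core/util/strided_slice_op.h (inferred from the call sites in
// this commit, not copied from the actual header).
namespace tensorflow {

Status ValidateStridedSliceOp(
    const Tensor& begin_tensor, const Tensor& end_tensor,
    const Tensor& strides_tensor, const TensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask,
    TensorShape* processing_shape, TensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides);

}  // namespace tensorflow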
313 changes: 22 additions & 291 deletions tensorflow/core/kernels/strided_slice_op.cc
@@ -35,286 +35,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {

namespace {

/// Constants
constexpr int32 kShrinkAxis = -1, kNewAxis = -2;

// Sparse slicing specification
// if one does foo[3:5, ..., -3], the begin/end/strides tensors here will
// each have length 3
struct StridedSliceSparseSpec {
int64 dims;
int32 num_add_axis_after_ellipsis;
const Tensor& begin_tensor;
const Tensor& end_tensor;
const Tensor& strides_tensor;
const int32 begin_mask, end_mask;
int32 ellipsis_mask;
const int32 new_axis_mask, shrink_axis_mask;
};

// Dense slicing specification
// all ellipses and new_axis entries are expanded out. So for
// foo[3:5, ..., -3] where foo is 10-dimensional,
// each InlinedVector will have 10 entries, whereas the
// sparse spec had length-3 tensors.
struct StridedSliceDenseSpec {
const int64 dims;
int32 begin_mask;
int32 end_mask;
gtl::InlinedVector<int64, 4>& begin;
gtl::InlinedVector<int64, 4>& end;
gtl::InlinedVector<int64, 4>& strides;
// This vector helps construct the final shape of the slice.
// The final tensor is reduced in rank whenever a single index e.g. foo[3]
// is called for. The final tensor increases in rank with tf.newaxis
// entries. If an index in this array is non-negative, the size of the dimension
// is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
// it will be 1. A shrunk dimension is skipped.
gtl::InlinedVector<int32, 4> final_shape_gather_indices;
// The dense-indexed shrink mask indicates which processing dimensions
// should be shrunk. For example, if foo.shape = (10,10,10,10),
// foo[3, ..., 5] has a sparse_shrink_axis_mask of 0x5 and a
// dense shrink_axis_mask of 0x9, yielding a final shape of (10,10).
int32 shrink_axis_mask;
};
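
// Illustrative sketch only (not part of the original file): how the sparse
// bitmask convention above plays out for foo[3, ..., 5] with foo.shape =
// (10, 10, 10, 10). Bit i of a sparse mask refers to sparse entry i (here
// the three entries "3", "...", "5"); after BuildDenseSpec below expands the
// ellipsis, bit i of a dense mask refers to input dimension i. The
// kExample* constants are hypothetical names used only for this sketch.
constexpr int32 kExampleSparseShrinkAxisMask = 0x5;  // entries 0 and 2 ("3", "5")
constexpr int32 kExampleSparseEllipsisMask = 0x2;    // entry 1 ("...")
constexpr int32 kExampleDenseShrinkAxisMask = 0x9;   // input dims 0 and 3
// The ellipsis covers input dims 1 and 2; dropping the two shrunk
// dimensions yields the final shape (10, 10) mentioned above.
static_assert(kExampleSparseShrinkAxisMask == ((1 << 0) | (1 << 2)) &&
                  kExampleSparseEllipsisMask == (1 << 1),
              "sparse masks index the three sparse entries");
static_assert(kExampleDenseShrinkAxisMask == ((1 << 0) | (1 << 3)),
              "dense shrink mask marks input dims 0 and 3");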

} // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <class T>
static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
StridedSliceDenseSpec* dense) {
// Build expanded begin, end, strides, begin_mask, end_mask
// to remove any ellipsis
dense->begin.resize(dense->dims);
dense->end.resize(dense->dims);
dense->strides.resize(dense->dims);
// What indices to get the final shape from.
dense->begin_mask = 0;
dense->end_mask = 0;
dense->shrink_axis_mask = 0;
{
int full_index = 0;

const auto& begin_flat = sparse.begin_tensor.flat<T>();
const auto& end_flat = sparse.end_tensor.flat<T>();
const auto& strides_flat = sparse.strides_tensor.flat<T>();

for (int i = 0; i < sparse.dims; i++) {
if ((1 << i) & sparse.ellipsis_mask) {
// Expand the ellipsis into the appropriate indices
// NOTE: this only works because we guaranteed one ellipsis
int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
sparse.num_add_axis_after_ellipsis,
dense->dims);
for (; full_index < next_index; full_index++) {
// new_axis entries aren't real axes, so skip them
dense->begin[full_index] = dense->end[full_index] = 0;
dense->strides[full_index] = 1;
dense->begin_mask |= (1 << full_index);
dense->end_mask |= (1 << full_index);
dense->final_shape_gather_indices.push_back(full_index);
}
} else if ((1 << i) & sparse.new_axis_mask) {
dense->final_shape_gather_indices.push_back(kNewAxis);
} else {
// Gather slicing spec into appropriate index
dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
dense->strides[full_index] =
internal::SubtleMustCopy<T>(strides_flat(i));
if (sparse.begin_mask & (1 << i)) {
dense->begin_mask |= (1 << full_index);
}
if (sparse.end_mask & (1 << i)) {
dense->end_mask |= (1 << full_index);
}
// If shrink, record where to get the dimensionality from (i.e.
// new_axis creates a fake size-1 dimension). Also remember the shrink
// axis (now in dense form) so we can ignore dense->end below.
if (sparse.shrink_axis_mask & (1 << i)) {
dense->final_shape_gather_indices.push_back(kShrinkAxis);
dense->shrink_axis_mask |= (1 << full_index);
} else {
dense->final_shape_gather_indices.push_back(full_index);
}
full_index++;
}
}
}
}
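
// Illustrative sketch only (not part of the original file): the ellipsis
// expansion arithmetic above, worked for foo[3:5, ..., -3] on a
// 10-dimensional foo. The sparse spec has 3 entries and the ellipsis is
// entry i = 1, so it must absorb every dense dimension not claimed by the
// entries before or after it. ExampleEllipsisExpansion is a hypothetical
// helper name used only for this sketch.
static inline void ExampleEllipsisExpansion() {
  const int sparse_dims = 3;                  // "3:5", "...", "-3"
  const int dense_dims = 10;                  // rank of foo
  const int i = 1;                            // position of the ellipsis
  const int num_add_axis_after_ellipsis = 0;  // no tf.newaxis after "..."
  const int next_index = std::min(
      dense_dims - (sparse_dims - i) + 1 + num_add_axis_after_ellipsis,
      dense_dims);
  // full_index is 1 when the ellipsis is reached (entry "3:5" consumed
  // dimension 0), so the ellipsis expands over dense dimensions 1..8 and
  // "-3" lands on dimension 9.
  CHECK_EQ(next_index, 9);
}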

// Shared code that is not dependent on the type of T. We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
static void SharedValidation(
OpKernelContext* context, const TensorShape& input_shape,
int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
const Tensor& begin_tensor = context->input(1);
const Tensor& end_tensor = context->input(2);
const Tensor& strides_tensor = context->input(3);
OP_REQUIRES(
context, TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(end_tensor.shape()) &&
TensorShapeUtils::IsVector(strides_tensor.shape()) &&
strides_tensor.dims() == 1 &&
strides_tensor.dims() == begin_tensor.dims() &&
strides_tensor.dims() == end_tensor.dims() &&
begin_tensor.dim_size(0) == end_tensor.dim_size(0) &&
begin_tensor.dim_size(0) == strides_tensor.dim_size(0) &&
begin_tensor.dim_size(0) < 32, // using 32 bit masks
errors::InvalidArgument(
"Expected begin, end, and strides to be 1D equal size tensors, ",
"but got shapes ", begin_tensor.shape().DebugString(), ", ",
end_tensor.shape().DebugString(), ", and ",
strides_tensor.shape().DebugString(), " instead."));
// Use bit compares to ensure ellipsis_mask is 0 or a power of 2
// i.e. there is no more than one ellipsis
OP_REQUIRES(context,
!ellipsis_mask || (ellipsis_mask & (ellipsis_mask - 1)) == 0,
errors::InvalidArgument("Multiple ellipsis' in slice "
"spec not allowed"));

// Step 1: Account for ellipsis and new axis
//
// Check for an ellipsis and count how many new_axis entries appear after it
// TODO(aselle): Convert this to do a fast log2 followed by iteration
// counting ones in the remaining bits
bool ellipsis_seen = false;

StridedSliceSparseSpec sparse_spec = {begin_tensor.NumElements(),
0,
begin_tensor,
end_tensor,
strides_tensor,
begin_mask_spec,
end_mask_spec,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask};

for (int32 i = 0; i < sparse_spec.dims; i++) {
if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
sparse_spec.num_add_axis_after_ellipsis++;
}
if ((1 << i) & ellipsis_mask) {
ellipsis_seen = true;
}
}
// If no ellipsis was seen, insert one at the end
if (!ellipsis_seen) {
sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
sparse_spec.dims++;  // this affects the loop iteration below
}

// Step 2: Make a sparse spec into a full index spec
//
// The sparse spec does not correspond to the number of dimensions;
// make a dense spec that corresponds to the number of dimensions.
//
// For example suppose foo[...,3:] on foo.shape=(2,2,3) then
// we need to produce the missing begin_mask for the first two
// dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
// we achieve begin_mask=3, end_mask=7
StridedSliceDenseSpec dense_spec = {
input_shape.dims(), 0, 0, *begin, *end, *strides};

if (begin_tensor.dtype() == DT_INT32) {
BuildDenseSpec<int32>(sparse_spec, &dense_spec);
} else if (begin_tensor.dtype() == DT_INT64) {
BuildDenseSpec<int64>(sparse_spec, &dense_spec);
} else {
LOG(FATAL) << "begin must be either int32 or int64";
}

// Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
// and bounds check!
*is_identity = true;
*slice_dim0 = true;
*is_simple_slice = true;
for (int i = 0; i < dense_spec.dims; ++i) {
int64& begin_i = (*begin)[i];
int64& end_i = (*end)[i];
int64& stride_i = (*strides)[i];
int64 dim_i = input_shape.dim_size(i);
OP_REQUIRES(context, stride_i != 0,
errors::InvalidArgument("strides[", i, "] must be non-zero"));

int64 masks[] = {dense_spec.begin_mask & (1 << i),
dense_spec.end_mask & (1 << i)};
int64 valid_range[] = {stride_i > 0 ? 0 : -1,
stride_i > 0 ? dim_i : dim_i - 1};

auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) {
if (masks[c]) {
return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
} else {
int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive
return x_fwd < valid_range[0]
? valid_range[0]
: x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
}
};
if (dense_spec.shrink_axis_mask & (1 << i)) {
// If we are shrinking, the end index is now possibly incorrect. In
// particular foo[-1] produces sparse_begin = -1, sparse_end = 0,
// and canonical puts these at n-1 and 0, which implies a degenerate
// interval. Fortunately, it is now safe to re-create end as begin+1.
int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
begin_i = x_fwd;
end_i = begin_i + 1;
OP_REQUIRES(context, stride_i > 0,
errors::InvalidArgument("only stride 1 allowed on"
" non-range indexing."));
OP_REQUIRES(
context, x_fwd >= 0 && x_fwd < dim_i,
errors::InvalidArgument("slice index ", begin_i, " of dimension ", i,
" out of bounds."));
} else {
begin_i = canonical(begin_i, 0);
end_i = canonical(end_i, 1);
}
// Update optimization values
(*is_simple_slice) &= stride_i == 1;
bool take_all_in_dimension =
stride_i == 1 && begin_i == 0 && end_i == input_shape.dim_size(i);
(*is_identity) &= take_all_in_dimension;
(*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;

// Compute the processing shape (the intermediate shape that Eigen will produce)
int64 interval_length = end_i - begin_i;
int64 size_i;
// Hold zero if the interval is degenerate; otherwise account for the remainder
if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0)))
size_i = 0;
else
size_i = interval_length / stride_i +
(interval_length % stride_i != 0 ? 1 : 0);
processing_shape->AddDim(size_i);
}

// Step 4: Compute the final shape
//
// new_axis will increase the rank by 1 (adding a size-1 dimension);
// slices like foo[3, ...] will reduce the rank by 1.
// This cannot be done earlier, because it depends on Step 3.
for (auto gather_index : dense_spec.final_shape_gather_indices) {
if (gather_index >= 0)
final_shape->AddDim(processing_shape->dim_size(gather_index));
else if (gather_index == kNewAxis)
final_shape->AddDim(1);
}
}
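
// Illustrative sketch only (not part of the original file): standalone
// versions of the clamping done by the `canonical` lambda above (for the
// common stride > 0 case) and of the slice-size computation that follows
// it. The Example* helper names are hypothetical; this mirrors, but is
// not, the code above.
static inline int64 ExampleCanonicalPositiveStride(int64 x, int64 dim) {
  // Negative indices count from the end; the result is then clamped to the
  // valid range [0, dim] used when the stride is positive.
  int64 x_fwd = x < 0 ? dim + x : x;
  return std::max(int64{0}, std::min(x_fwd, dim));
}

static inline int64 ExampleSliceSize(int64 begin, int64 end, int64 stride) {
  // Number of elements produced by begin:end:stride; zero for a degenerate
  // interval, otherwise interval_length / stride rounded away from zero.
  int64 interval_length = end - begin;
  if (interval_length == 0 || ((interval_length < 0) != (stride < 0))) {
    return 0;
  }
  return interval_length / stride + (interval_length % stride != 0 ? 1 : 0);
}
// For foo.shape = (10,) and foo[-8:9:2]: begin -8 maps to 2, end 9 stays 9,
// and the size is ceil((9 - 2) / 2) = 4 (elements 2, 4, 6, 8).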

template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
public:
@@ -336,11 +60,13 @@ class StridedSliceOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;

SharedValidation(context, context->input(0).shape(), begin_mask, end_mask,
ellipsis_mask, new_axis_mask, shrink_axis_mask,
&processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides);
if (!context->status().ok()) return;
OP_REQUIRES_OK(context,
ValidateStridedSliceOp(
context->input(1), context->input(2), context->input(3),
context->input(0).shape(), begin_mask, end_mask,
ellipsis_mask, new_axis_mask, shrink_axis_mask,
&processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides));

const Tensor& input = context->input(0);

@@ -460,10 +186,13 @@ class StridedSliceGradOp : public OpKernel {
LOG(FATAL) << "shape must have type int32 or int64.";
}

SharedValidation(context, input_shape, begin_mask, end_mask, ellipsis_mask,
new_axis_mask, shrink_axis_mask, &processing_shape,
&final_shape, &is_identity, &is_simple_slice, &slice_dim0,
&begin, &end, &strides);
OP_REQUIRES_OK(
context,
ValidateStridedSliceOp(
context->input(1), context->input(2), context->input(3),
input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides));

// Check to make sure dy is consistent with the original slice
TensorShape dy_shape = context->input(4).shape();
Expand Down Expand Up @@ -527,11 +256,13 @@ class StridedSliceAssignOp : public OpKernel {
context->forward_ref_input_to_ref_output(0, 0);
Tensor old_lhs = context->mutable_input(0, true);

SharedValidation(context, old_lhs.shape(), begin_mask, end_mask,
ellipsis_mask, new_axis_mask, shrink_axis_mask,
&processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides);
if (!context->status().ok()) return;
OP_REQUIRES_OK(
context,
ValidateStridedSliceOp(
context->input(1), context->input(2), context->input(3),
old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides));

if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);