Skip to content

Commit

Permalink
slice operator supporting arbitrary values of step (#8558)
Browse files Browse the repository at this point in the history
* Implement slice op forward supporting abitrary step value

Implement slice op backward

Change parallelization approach for slicing

Add unit test

Fix lint and typo

Fix doc

Fix doc

Remove slice_v1

Address cr

Remove slice_v1 in .cu

Change step data type

Add fallback for slicing csr with non-trivial step

* Add error handling for take op infer shape
  • Loading branch information
reminisce authored and piiswrong committed Nov 8, 2017
1 parent bf2336c commit 70b68b1
Show file tree
Hide file tree
Showing 9 changed files with 418 additions and 122 deletions.
75 changes: 75 additions & 0 deletions src/common/static_array.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file static_array.h
*/
#ifndef MXNET_COMMON_STATIC_ARRAY_H_
#define MXNET_COMMON_STATIC_ARRAY_H_

#include <mshadow/base.h>

namespace mxnet {
namespace common {

/*! \brief
* Static array. This code is borrowed from struct Shape<ndim>,
* except that users can specify the type of the elements of
* the statically allocated array.
* The object instance of the struct is copyable between CPU and GPU.
* \tparam T element type of the array, must be copyable between CPU and GPU
* \tparam num number of elements in the array
*/
template<typename T, int num>
struct StaticArray {
static const int kNum = num;

T array_[kNum];

/*! \brief default constructor, do nothing */
MSHADOW_XINLINE StaticArray(void) {}

/*! \brief constructor, fill in the array with the input value */
MSHADOW_XINLINE StaticArray(const T& val) {
#pragma unroll
for (int i = 0; i < num; ++i) {
this->array_[i] = val;
}
}

/*! \brief constuctor */
MSHADOW_XINLINE StaticArray(const StaticArray<T, num>& sa) {
#pragma unroll
for (int i = 0; i < num; ++i) {
this->array_[i] = sa[i];
}
}

MSHADOW_XINLINE T& operator[](const index_t idx) {
return array_[idx];
}

MSHADOW_XINLINE const T& operator[](const index_t idx) const {
return array_[idx];
}
}; // StaticArray

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_STATIC_ARRAY_H_
22 changes: 22 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ inline int get_num_threads<cpu>(const int N) {
}


#define MXNET_NDIM_SWITCH(NDim, ndim, ...) \
if (NDim == 0) { \
} else if (NDim == 1) { \
const int ndim = 1; \
{__VA_ARGS__} \
} else if (NDim == 2) { \
const int ndim = 2; \
{__VA_ARGS__} \
} else if (NDim == 3) { \
const int ndim = 3; \
{__VA_ARGS__} \
} else if (NDim == 4) { \
const int ndim = 4; \
{__VA_ARGS__} \
} else if (NDim == 5) { \
const int ndim = 5; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ndim=" << NDim << "too large "; \
}


/*!
* \brief assign the val to out according
* to request in Kernel::Launch
Expand Down
3 changes: 2 additions & 1 deletion src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
scatter_nd(data, indices) = [[0, 0], [2, 3]]
shape = (2, 2)
scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]
)code")
.set_num_outputs(1)
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
const TShape &arrshape = (*in_attrs)[take_::kArr];
const TShape &idxshape = (*in_attrs)[take_::kIdx];
if (idxshape.ndim() == 0) return false;
if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false;

out_attrs->clear();

Expand Down
Loading

0 comments on commit 70b68b1

Please sign in to comment.