Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize tensor getitem. #5433

Merged
merged 10 commits into from
Jul 12, 2021
77 changes: 77 additions & 0 deletions oneflow/api/python/functional/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/api/python/functional/common.h"

namespace oneflow {
namespace one {
namespace functional {

namespace detail {

// Extracts the start/stop/step fields of a Python slice object into C integers.
//
// Missing (None) fields are replaced by defaults:
//   step  -> 1
//   start -> 0 for a non-negative step, PY_SSIZE_T_MAX for a negative step
//   stop  -> PY_SSIZE_T_MAX for a non-negative step, PY_SSIZE_T_MIN for a negative step
//
// `object` is reinterpreted as a PySliceObject without a type check, so callers must have
// verified it with PySlice_Check beforehand.
//
// Returns an error if any field cannot be converted by _PyEval_SliceIndex or if step is zero.
Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step) {
  PySliceObject* slice = reinterpret_cast<PySliceObject*>(object);

  // Renders repr(object) for error messages. PyObject_Repr returns a new reference, so the
  // text is copied out and the reference released instead of being leaked on every error.
  const auto SliceRepr = [object]() -> std::string {
    PyObject* repr = PyObject_Repr(object);
    if (repr == nullptr) { return "<unprintable slice>"; }
    const char* text = PyStringAsString(repr);
    std::string result(text != nullptr ? text : "<unprintable slice>");
    Py_DECREF(repr);
    return result;
  };

  if (slice->step == Py_None) {
    *step = 1;
  } else {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice->step, step)) << "Invalid slice " << SliceRepr();
    CHECK_NE_OR_RETURN(*step, 0) << "slice step cannot be zero.";
    // Keep the step within [-PY_SSIZE_T_MAX, PY_SSIZE_T_MAX]; presumably this mirrors CPython's
    // own slice unpacking so that the step can later be negated without overflow.
    if (*step < -PY_SSIZE_T_MAX) { *step = -PY_SSIZE_T_MAX; }
  }

  if (slice->start == Py_None) {
    *start = *step < 0 ? PY_SSIZE_T_MAX : 0;
  } else {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice->start, start)) << "Invalid slice " << SliceRepr();
  }

  if (slice->stop == Py_None) {
    *stop = *step < 0 ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX;
  } else {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice->stop, stop)) << "Invalid slice " << SliceRepr();
  }
  return Maybe<void>::Ok();
}

// Returns the UTF-8 encoded contents of a Python unicode object as a C string.
//
// The result is intended for diagnostics. If `object` is null or cannot be encoded, a fixed
// placeholder string is returned instead of passing a null pointer on to PyBytes_AsString.
//
// `object` is borrowed; this function does not release it.
const char* PyStringAsString(PyObject* object) {
  static const char* const kUnprintable = "<unprintable object>";
  if (object == nullptr) { return kUnprintable; }

  PyObject* bytes = PyUnicode_AsEncodedString(object, "utf-8", "~E~");
  if (bytes == nullptr) {
    // The encoding failure is reported through the placeholder, so drop the Python error.
    PyErr_Clear();
    return kUnprintable;
  }
  // NOTE: `bytes` is deliberately not released. The returned pointer aliases its internal
  // buffer and the `const char*` return type cannot carry ownership, so releasing it here
  // would leave callers with a dangling pointer.
  const char* text = PyBytes_AsString(bytes);
  return text != nullptr ? text : kUnprintable;
}

// Converts one element of a Python indexing expression into an IndexItem.
//
// Supported elements: Ellipsis, slice objects, True/False, Python integers and None.
// Any other object yields an "Invalid index" error, as does an integer that does not fit
// into a 64-bit signed value.
Maybe<detail::IndexItem> UnpackIndexItem(PyObject* object) {
  if (object == Py_Ellipsis) {
    return std::make_shared<detail::IndexItem>(detail::EllipsisIndex{});
  }
  if (PySlice_Check(object)) {
    Py_ssize_t start, end, step;
    JUST(PySliceUnpack(object, &start, &end, &step));
    return std::make_shared<detail::IndexItem>(start, end, step);
  }
  // True and False also satisfy PyLong_Check, so they are handled before the integer case.
  if (object == Py_True || object == Py_False) {
    return std::make_shared<detail::IndexItem>(object == Py_True);
  }
  if (PyLong_Check(object)) {
    const long long value = PyLong_AsLongLong(object);
    // PyLong_AsLongLong reports overflow by returning -1 with a Python error set. Without
    // this check an out-of-range index would silently be treated as index -1.
    const bool overflowed = (value == -1 && PyErr_Occurred() != nullptr);
    if (overflowed) { PyErr_Clear(); }
    CHECK_OR_RETURN(!overflowed) << "Index is out of the range of a 64-bit signed integer.";
    return std::make_shared<detail::IndexItem>(static_cast<int64_t>(value));
  }
  if (object == Py_None) { return std::make_shared<detail::IndexItem>(detail::NoneIndex{}); }

  // PyObject_Repr returns a new reference: copy the text and release it rather than leak it.
  std::string repr_text = "<unprintable object>";
  PyObject* repr = PyObject_Repr(object);
  if (repr != nullptr) {
    const char* text = PyStringAsString(repr);
    if (text != nullptr) { repr_text = text; }
    Py_DECREF(repr);
  }
  UNIMPLEMENTED_THEN_RETURN() << "Invalid index " << repr_text;
}

} // namespace detail

} // namespace functional
} // namespace one
} // namespace oneflow
7 changes: 7 additions & 0 deletions oneflow/api/python/functional/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ limitations under the License.
#include <vector>
#include <pybind11/pybind11.h>

#include "oneflow/api/python/framework/throw.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/tensor_index.h"

namespace py = pybind11;

Expand Down Expand Up @@ -130,6 +132,11 @@ template<typename T>
return values;
}

// Reads the start/stop/step fields of the Python slice `object` into the output pointers,
// substituting defaults for fields that are None. Fails on a zero step or on a field that
// cannot be converted to an index.
Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step);
// Encodes the Python unicode `object` as UTF-8 and returns the resulting C string.
const char* PyStringAsString(PyObject* object);

// Converts a single Python index element (Ellipsis, slice, integer, True/False or None) into
// an IndexItem; any other object is reported as an invalid index.
Maybe<detail::IndexItem> UnpackIndexItem(PyObject* object);

} // namespace detail

} // namespace functional
Expand Down
33 changes: 33 additions & 0 deletions oneflow/api/python/functional/python_arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_attr.cfg.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/scalar.h"
#include "oneflow/core/functional/tensor_index.h"

namespace py = pybind11;

Expand Down Expand Up @@ -167,6 +168,38 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
}

// Converts the wrapped Python object into a TensorIndex.
//
// A lone Ellipsis, slice, integer, bool or None becomes a one-element index. Any other object
// is treated as a sequence whose elements are each converted with detail::UnpackIndexItem.
// Returns an error if the object is not a sequence or if any element is not a valid index.
template<>
Maybe<TensorIndex> PythonArg::ObjectAs<TensorIndex>() const {
  auto tensor_index = std::make_shared<TensorIndex>();

  // PyLong_Check also accepts True/False, so bools are covered by this test.
  const bool is_single_item = object_ == Py_Ellipsis || PySlice_Check(object_)
                              || PyLong_Check(object_) || object_ == Py_None;
  if (is_single_item) {
    // Share the per-element conversion with the sequence path instead of duplicating it.
    const auto& item = JUST(detail::UnpackIndexItem(object_));
    tensor_index->emplace_back(*item);
    return tensor_index;
  }

  // PySequence_Tuple returns a new reference, or NULL for a non-iterable object. Owning it
  // through py::object releases it on every exit path, including early returns from JUST.
  const py::object tuple = py::reinterpret_steal<py::object>(PySequence_Tuple(object_));
  const bool is_sequence = tuple.ptr() != nullptr;
  if (!is_sequence) { PyErr_Clear(); }
  CHECK_OR_RETURN(is_sequence)
      << "Tensor index must be an integer, slice, ellipsis, None, bool or a sequence of them.";

  const size_t size = PyTuple_GET_SIZE(tuple.ptr());
  tensor_index->resize(size);
  for (size_t i = 0; i < size; ++i) {
    // PyTuple_GET_ITEM yields a borrowed reference that stays valid while `tuple` is alive.
    PyObject* element = PyTuple_GET_ITEM(tuple.ptr(), i);
    const auto& item = JUST(detail::UnpackIndexItem(element));
    tensor_index->at(i) = *item;
  }
  return tensor_index;
}

} // namespace functional
} // namespace one
} // namespace oneflow
6 changes: 5 additions & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# {
# "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool",
# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList",
# "BoolList", "DataType", "Shape"
# "BoolList", "DataType", "Shape", "Generator", "TensorIndex"
# }

- name: "add_n"
Expand Down Expand Up @@ -593,3 +593,7 @@
- name: "pad_grad"
signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)"
bind_python: False

# Indexes tensor `x` with a TensorIndex argument; exposed to Python (bind_python: True).
- name: "tensor_getitem"
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True
62 changes: 62 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
Expand Down Expand Up @@ -457,6 +458,66 @@ class TriuFunctor {
std::shared_ptr<OpExpr> op_;
};

// Functor implementing tensor indexing.
//
// The index is first normalized with RegularTensorIndex. Slice and integer items are then
// turned into per-dimension start/end/step vectors, while None and True items add a dimension
// of size 1 to the result. The data is fetched with Slice (or Copy when the whole tensor is
// selected) and reshaped when the computed result dimensions differ from the fetched shape.
class TensorGetItemFunctor {
 public:
  TensorGetItemFunctor() {}
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {
    const auto& regular_index = JUST(RegularTensorIndex(index, *(x->shape())));
    int64_t ndims = x->shape()->NumAxes();
    CHECK_GE_OR_RETURN(regular_index->size(), ndims) << "Tensor index failed to be regularlized.";

    std::vector<int64_t> start(ndims), end(ndims), step(ndims);
    DimVector result_dims;
    // `dim` counts the tensor dimensions consumed so far by slice and integer items.
    int dim = 0;
    for (int i = 0; i < regular_index->size(); ++i) {
      const auto& item = regular_index->at(i);
      CHECK_OR_RETURN(!item.IsEllipsis())
          << "Tensor index should not have ellipsis once regularlized.";

      if (item.IsSlice()) {
        CHECK_LT_OR_RETURN(dim, ndims);
        const auto& slice = item.slice();
        start[dim] = slice.start();
        end[dim] = slice.end();
        step[dim] = slice.step();
        // Number of selected elements: ceil((end - start) / step).
        const int64_t length = (end[dim] - start[dim] + step[dim] - 1) / step[dim];
        result_dims.emplace_back(length);
        dim++;
        continue;
      }
      if (item.IsInteger()) {
        // An integer selects a single element and contributes no result dimension.
        CHECK_LT_OR_RETURN(dim, ndims);
        start[dim] = item.integer();
        end[dim] = start[dim] + 1;
        step[dim] = 1;
        dim++;
        continue;
      }
      if (item.IsNone()) {
        result_dims.emplace_back(1);
        continue;
      }
      if (item.IsBoolean()) {
        CHECK_OR_RETURN(item.boolean()) << "Index false is not supported.";
        result_dims.emplace_back(1);
      }
    }
    CHECK_EQ_OR_RETURN(dim, ndims)
        << "Specified dims count for regularlized tensor index should equal to tensor dimension "
        << ndims;

    // The index is an identity when every dimension is selected in full with unit step.
    bool is_identity = true;
    for (int i = 0; i < ndims; ++i) {
      if (start[i] != 0 || end[i] != x->shape()->At(i) || step[i] != 1) {
        is_identity = false;
        break;
      }
    }

    std::shared_ptr<one::Tensor> result;
    if (is_identity) {
      const auto& device = JUST(x->device());
      result = JUST(functional::Copy(x, device->type(), device->device_id()));
    } else {
      result = JUST(functional::Slice(x, start, end, step));
    }

    Shape shape(result_dims);
    const bool needs_reshape = shape.NumAxes() != 0 && shape != *(result->shape());
    if (needs_reshape) { return functional::Reshape(result, shape); }
    return result;
  }
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand Down Expand Up @@ -484,6 +545,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::UpsampleFunctor>("Upsample");
m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
m.add_functor<impl::TriuFunctor>("Triu");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
};

} // namespace functional
Expand Down
87 changes: 87 additions & 0 deletions oneflow/core/functional/tensor_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/functional/tensor_index.h"

namespace oneflow {
namespace one {
namespace functional {

// Returns how many items of `index` are slices or integers, i.e. the items that each
// consume one tensor dimension in RegularTensorIndex.
int64_t CountSpecifiedDims(const TensorIndex& index) {
  int64_t count = 0;
  for (int pos = 0; pos < index.size(); ++pos) {
    const auto& item = index.at(pos);
    const bool consumes_dim = item.IsSlice() || item.IsInteger();
    if (consumes_dim) { ++count; }
  }
  return count;
}

// Normalizes `index` against `shape`.
//
// In the returned index:
//   - slices have start/end clamped to the dimension size, with negative values offset by it;
//   - integers are made non-negative and bounds-checked;
//   - an ellipsis is expanded into full slices for the dimensions it stands for;
//   - None and True items are passed through unchanged (False is rejected);
//   - dimensions not mentioned by the index are appended as full slices.
// Only positive slice steps are accepted.
Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape) {
  const int64_t specified_ndims = CountSpecifiedDims(index);
  const int64_t ndims = shape.NumAxes();
  CHECK_LE_OR_RETURN(specified_ndims, ndims)
      << "Too many indices for tensor of dimension " << ndims;

  auto regular_index = std::make_shared<TensorIndex>();
  // Next tensor dimension to be consumed by a slice, integer or ellipsis item.
  int64_t dim = 0;
  for (int i = 0; i < index.size(); ++i) {
    const auto& item = index.at(i);

    if (item.IsSlice()) {
      CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
      const int64_t dim_size = shape.At(dim);
      CHECK_GT_OR_RETURN(dim_size, 0) << "Slice cannot be applied to a 0-dim tensor.";
      const auto& slice = item.slice();
      const int64_t step = std::min(slice.step(), dim_size);
      CHECK_GT_OR_RETURN(step, 0) << "Step must be greater than zero.";
      int64_t end = std::min(slice.end(), dim_size);
      int64_t start = std::min(slice.start(), dim_size);
      // Negative positions count from the end; a start still below zero is clamped to 0.
      if (start < 0) { start += dim_size; }
      if (start < 0) { start = 0; }
      if (end < 0) { end += dim_size; }
      // An end before the start denotes an empty range.
      if (end < start) { end = start; }
      regular_index->emplace_back(detail::IndexItem(start, end, step));
      dim++;
      continue;
    }

    if (item.IsInteger()) {
      CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
      const int64_t dim_size = shape.At(dim);
      int64_t integer = item.integer();
      if (integer < 0) { integer += dim_size; }
      CHECK_OR_RETURN(integer >= 0 && integer < dim_size)
          << "Index " << item.integer() << " is out of bounds for dimension " << dim
          << " with size " << dim_size;
      regular_index->emplace_back(detail::IndexItem(integer));
      dim++;
      continue;
    }

    if (item.IsEllipsis()) {
      // The ellipsis covers the dimensions not claimed by slices/integers, capped by the
      // number of dimensions still left.
      const int64_t expand_count = std::min(ndims - dim, ndims - specified_ndims);
      for (int j = 0; j < expand_count; ++j) {
        regular_index->emplace_back(detail::IndexItem(0, shape.At(dim), 1));
        dim++;
      }
      continue;
    }

    // None or Boolean.
    if (item.IsBoolean()) {
      CHECK_OR_RETURN(item.boolean()) << "Index false is not supported.";
    }
    regular_index->emplace_back(item);
  }

  // Dimensions the index did not reach are selected in full.
  while (dim < ndims) {
    regular_index->emplace_back(detail::IndexItem(0, shape.At(dim), 1));
    dim++;
  }
  return regular_index;
}

} // namespace functional
} // namespace one
} // namespace oneflow
Loading