Skip to content

Commit

Permalink
[slice] Support indices that are Tensors for slicing in dynamic mode (PaddlePadd…
Browse files Browse the repository at this point in the history
  • Loading branch information
liym27 authored Apr 25, 2021
1 parent 25e723e commit aceec7f
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ void BindImperative(py::module *m_ptr) {
// inplace operator for the VarBase self.
self->BumpInplaceVersion();
})
.def("__getitem__",
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, infer_flags;
Expand Down
30 changes: 28 additions & 2 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .. import framework
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
Expand Down Expand Up @@ -437,6 +437,31 @@ def __bool__(self):
def __array__(self, dtype=None):
return self.numpy().astype(dtype)

def __getitem__(self, item):
    """Index/slice this VarBase, dispatching on whether the index holds a Tensor.

    Args:
        item: anything Python indexing syntax produces — an int, a slice,
            a Variable, or a tuple mixing those.

    Returns:
        The sliced VarBase.
    """

    def _has_tensor_index(index):
        # Normalize a lone index into a sequence so a single loop covers
        # both `t[x]` and `t[x, y, ...]`.
        entries = index if isinstance(index, tuple) else [index]
        for entry in entries:
            if isinstance(entry, slice):
                # A slice counts as "contains tensor" when any of its
                # start/stop/step bounds is a Variable.
                if any(
                        isinstance(bound, Variable)
                        for bound in (entry.start, entry.stop, entry.step)):
                    return True
            elif isinstance(entry, Variable):
                return True
        return False

    if _has_tensor_index(item):
        # 1. Take the Python implementation when the index contains a
        #    Tensor: such an index can't be parsed by the C++ binding.
        return _getitem_impl_(self, item)
    # 2. Otherwise use the C++ fast path _getitem_index_not_tensor.
    return self._getitem_index_not_tensor(item)

for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value),
Expand All @@ -445,7 +470,8 @@ def __array__(self, dtype=None):
("gradient", gradient), ("register_hook", register_hook),
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__)):
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__)):
setattr(core.VarBase, method_name, method)

# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,70 @@ def _test_slice(self):
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))

def _test_slice_for_tensor_attr(self):
    """Exercise __getitem__ when slice bounds / point indices are Tensors.

    Mirrors _test_slice, but the int bounds used there are replaced with
    1-element int32 Tensors, so the Python `_getitem_impl_` path (taken
    when the index contains a Tensor) is the code under test. Every
    result is checked against NumPy slicing of the same source array.
    """
    # Fixed 3x3x3 float32 input with distinct values 1..27 so any
    # mis-slice is caught by the array_equal checks below.
    tensor_array = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
         [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
         [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]).astype('float32')

    var = paddle.to_tensor(tensor_array)

    # 1-element int32 Tensors used as slice bounds / indices below.
    one = paddle.ones(shape=[1], dtype="int32")
    two = paddle.full(shape=[1], fill_value=2, dtype="int32")
    negative_one = paddle.full(shape=[1], fill_value=-1, dtype="int32")
    four = paddle.full(shape=[1], fill_value=4, dtype="int32")

    # NOTE(review): `var` is immediately rebound here, so the
    # paddle.to_tensor result above is never exercised — looks leftover.
    var = fluid.dygraph.to_variable(tensor_array)
    var1 = var[0, one, one]
    var2 = var[one:]
    var3 = var[0:one]
    var4 = var[::negative_one]
    var5 = var[one, one:, one:]
    var_reshape = fluid.layers.reshape(var, [3, negative_one, 3])
    var6 = var_reshape[:, :, negative_one]
    var7 = var[:, :, :negative_one]
    var8 = var[:one, :one, :1]
    var9 = var[:-1, :negative_one, :negative_one]
    var10 = var[::negative_one, :one, :negative_one]
    var11 = var[:negative_one, ::-1, negative_one:]
    var12 = var[one:2, 2:, ::negative_one]
    var13 = var[two:10, 2:, -2:negative_one]
    var14 = var[1:negative_one, 0:2, ::negative_one]
    var15 = var[::negative_one, ::-1, ::negative_one]
    var16 = var[-4:4]

    # Materialize every sliced Tensor; local_out[0] is the unsliced
    # input and is intentionally never asserted against.
    vars = [
        var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
        var11, var12, var13, var14, var15, var16
    ]
    local_out = [var.numpy() for var in vars]

    # NOTE(review): var1 used Tensor point indices [0, one, one] but is
    # compared to tensor_array[0, 1, 1:2] — a Tensor point index
    # apparently keeps a size-1 dim like a 1-wide slice; confirm this is
    # the intended contract.
    self.assertTrue(np.array_equal(local_out[1], tensor_array[0, 1, 1:2]))
    self.assertTrue(np.array_equal(local_out[2], tensor_array[1:]))
    self.assertTrue(np.array_equal(local_out[3], tensor_array[0:1]))
    self.assertTrue(np.array_equal(local_out[4], tensor_array[::-1]))
    self.assertTrue(np.array_equal(local_out[5], tensor_array[1, 1:, 1:]))
    self.assertTrue(
        np.array_equal(local_out[6],
                       tensor_array.reshape((3, -1, 3))[:, :, -1]))
    self.assertTrue(np.array_equal(local_out[7], tensor_array[:, :, :-1]))
    self.assertTrue(np.array_equal(local_out[8], tensor_array[:1, :1, :1]))
    self.assertTrue(
        np.array_equal(local_out[9], tensor_array[:-1, :-1, :-1]))
    self.assertTrue(
        np.array_equal(local_out[10], tensor_array[::-1, :1, :-1]))
    self.assertTrue(
        np.array_equal(local_out[11], tensor_array[:-1, ::-1, -1:]))
    self.assertTrue(
        np.array_equal(local_out[12], tensor_array[1:2, 2:, ::-1]))
    self.assertTrue(
        np.array_equal(local_out[13], tensor_array[2:10, 2:, -2:-1]))
    self.assertTrue(
        np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1]))
    self.assertTrue(
        np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
    self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))

def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
Expand All @@ -483,6 +547,7 @@ def _test_for_var(self):
def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
self._test_slice_for_tensor_attr()
self._test_for_var()

var = fluid.dygraph.to_variable(self.array)
Expand Down

0 comments on commit aceec7f

Please sign in to comment.