Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,33 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__setitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
auto self_numpy = TensorToPyArray(*self_tensor);

if (py::isinstance<py::array>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::float_>(value_obj)) {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);

} else {
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);

self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
}
})
.def("__getitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends,
Expand Down Expand Up @@ -811,7 +838,8 @@ void BindImperative(py::module *m_ptr) {
return framework::vectorize<int>(
self.Var().Get<framework::SelectedRows>().value().dims());
} else {
VLOG(2) << "It is meaningless to get shape of variable type "
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
Expand Down
55 changes: 52 additions & 3 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from __future__ import print_function

import unittest
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
import numpy as np
import six

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import numpy as np
import paddle.fluid.layers as layers
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode


class TestVarBase(unittest.TestCase):
Expand Down Expand Up @@ -403,5 +405,52 @@ def _assert_to_static(self, var_base, static_var, is_param=False):
self.assertListEqual(list(var_base.shape), list(static_var.shape))


class TestVarBaseSetitem(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.np_value = np.random.random((2, 3)).astype(np.float32)
self.tensor_value = paddle.to_tensor(self.np_value)

def _test(self, value):
paddle.disable_static()
id_origin = id(self.tensor_x)

self.tensor_x[0] = value

if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(np.float32) + value

else:
result = self.np_value

self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

self.tensor_x[1:2] = value
self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

self.tensor_x[...] = value
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

def test_value_tensor(self):
paddle.disable_static()
self._test(self.tensor_value)

def test_value_numpy(self):
paddle.disable_static()
self._test(self.np_value)

def test_value_int(self):
paddle.disable_static()
self._test(10)

def test_value_float(self):
paddle.disable_static()
self._test(3.3)


if __name__ == '__main__':
unittest.main()