Set value with scalar #60452

Merged: 3 commits, Jan 2, 2024
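The change is user-visible when the right-hand side of a subscript assignment is a plain Python scalar. A minimal before/after sketch of the behavior this PR targets (assuming a standard paddle build; op names as in the diff below):

import paddle

x = paddle.zeros([3, 4], dtype='float32')

# Scalar value: after this PR it is parsed into a phi::Scalar in C++ and
# dispatched to the inplace set_value_ op; previously the Python layer
# first wrapped it in a 1-element Tensor via paddle.to_tensor.
x[0] = 1.5

# Tensor value: unchanged, still dispatched to set_value_with_tensor_.
x[1] = paddle.ones([4], dtype='float32')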
Changes from all commits
92 changes: 52 additions & 40 deletions paddle/fluid/pybind/eager_method.cc
@@ -1607,58 +1607,70 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

// step2: Parse values
PADDLE_ENFORCE(
PyCheckTensor(value_obj),
platform::errors::InvalidArgument("The value must be a Tensor"));

std::vector<phi::Scalar> values;
paddle::Tensor value_tensor =
reinterpret_cast<TensorObject*>(value_obj)->tensor;
dealWithValues(tensor, value_obj, &values, has_advanced_index);

if (!has_advanced_index) {
// use set_value OP if there is no advanced index

// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (value_tensor.initialized()) {
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
}
}
}

// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor;
// passing stop_gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor;
// passing stop_gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
}
} else {
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor);
}
self->tensor = set_value__ad_func(self->tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes,
{1},
values);
}
} else {
// step3.2: Case for there are advanced indexing.
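The dtype logic above can be observed from Python. A small sketch of the two basic-indexing paths (assuming AMP is off, i.e. AmpLevel O0):

import paddle

x = paddle.zeros([4], dtype='int32')
# Basic indexing with a scalar takes the set_value_ path; the Python int
# travels as a phi::Scalar in `values`, with no temporary Tensor.
x[0] = 7

y = paddle.zeros([4], dtype='float32')
v = paddle.to_tensor([2.0], dtype='float64')
# Basic indexing with a Tensor takes set_value_with_tensor_; the dtype
# mismatch is resolved by cast_ad_func, casting v to float32 first.
y[1:2] = v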
101 changes: 101 additions & 0 deletions paddle/fluid/pybind/slice_utils.h
@@ -27,9 +27,11 @@
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

@@ -534,5 +536,104 @@ static void ParseBoolAndBroadcastIndices(
}
}

// Parse the value object of a __setitem__ call: a paddle.Tensor or
// numpy.ndarray is returned as a Tensor; a Python scalar is appended to
// *values as a phi::Scalar (and, when trans_to_tensor is true, also
// materialized as a 1-element Tensor).
static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
PyObject* value_obj,
std::vector<phi::Scalar>* values,
const bool trans_to_tensor) {
paddle::Tensor value_tensor;
if (PyCheckTensor(value_obj)) {
value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
} else if (py::isinstance<py::array>(value_obj)) {
paddle::Tensor value_tensor_tmp(
std::make_shared<phi::DenseTensor>(),
egr::Controller::Instance().GenerateUniqueName());
py::object value_obj_tmp(py::handle(value_obj), true);
py::object value = value_obj_tmp;
if (tensor.dtype() == phi::DataType::FLOAT32) {
if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::FLOAT64) {
if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::INT32) {
if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::INT64) {
if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::BOOL) {
if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32 or int64, "
"please check the type of tensor."));
}
SetTensorFromPyArray(
static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
value,
tensor.place(),
false);
value_tensor = value_tensor_tmp;
} else {
py::object value_obj_tmp(py::handle(value_obj), true);
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj_tmp) ||
py::isinstance<py::int_>(value_obj_tmp) ||
py::isinstance<py::bool_>(value_obj_tmp) ||
PyComplex_Check(value_obj)) {
if (tensor.dtype() == phi::DataType::FLOAT32 ||
tensor.dtype() == phi::DataType::FLOAT16 ||
tensor.dtype() == phi::DataType::BFLOAT16) {
values->push_back(value_obj_tmp.cast<float>());
} else if (tensor.dtype() == phi::DataType::FLOAT64) {
values->push_back(value_obj_tmp.cast<double>());
} else if (tensor.dtype() == phi::DataType::INT32 ||
tensor.dtype() == phi::DataType::INT16 ||
tensor.dtype() == phi::DataType::INT8 ||
tensor.dtype() == phi::DataType::UINT8) {
values->push_back(value_obj_tmp.cast<float>());
} else if (tensor.dtype() == phi::DataType::INT64) {
values->push_back(value_obj_tmp.cast<double>());
} else if (tensor.dtype() == phi::DataType::BOOL) {
values->push_back(value_obj_tmp.cast<bool>());
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {
values->push_back(value_obj_tmp.cast<std::complex<float>>());
} else if (tensor.dtype() == phi::DataType::COMPLEX128) {
values->push_back(value_obj_tmp.cast<std::complex<double>>());
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Value type error. The assign value allows "
"Tensor, numpy.ndarray, integer, float, complex or bool, "
"but received %s.",
Py_TYPE(value_obj)));
}

if (trans_to_tensor) {
value_tensor =
full_ad_func({1}, (*values)[0], tensor.dtype(), tensor.place());
}
}
return value_tensor;
}

} // namespace pybind
} // namespace paddle
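A usage-level sketch of the two non-Tensor branches of dealWithValues (assuming standard paddle and numpy; the casts happen on the C++ side):

import numpy as np
import paddle

x = paddle.zeros([2, 3], dtype='float32')
# numpy branch: an ndarray with a different dtype is first cast to the
# tensor's dtype (CastNumpyArray<float> for float32), then copied in
# with SetTensorFromPyArray.
x[0] = np.array([1, 2, 3])          # int64 ndarray -> float32
x[1] = np.array([0.1, 0.2, 0.3])    # float64 ndarray -> float32

z = paddle.zeros([3], dtype='complex64')
# scalar branch: a Python complex becomes a std::complex<float> Scalar.
z[0] = 1 + 2j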
11 changes: 4 additions & 7 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -975,7 +975,7 @@ def __array__(self, dtype=None):
array = array.astype(dtype)
return array

def pre_deal_index_and_value(self, item, value=None):
def pre_deal_index(self, item):
# since in pybind there is no efficient way to transfer Py_Tuple/Py_List/Py_Range to Tensor
# we call this function at the python level.
item = list(item) if isinstance(item, tuple) else [item]
@@ -985,17 +985,14 @@ def pre_deal_index_and_value(self, item, value=None):
elif isinstance(slice_item, range):
item[i] = paddle.to_tensor(list(slice_item))

if value is not None and not isinstance(value, Variable):
value = paddle.to_tensor(value, dtype=self.dtype)

return tuple(item), value
return tuple(item)

def __getitem__(self, item):
item, _ = pre_deal_index_and_value(self, item)
item = pre_deal_index(self, item)
return self._getitem_dygraph(item)

def __setitem__(self, item, value):
item, value = pre_deal_index_and_value(self, item, value)
item = pre_deal_index(self, item)
return self._setitem_dygraph(item, value)

@framework.dygraph_only
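A short sketch of the slimmed-down pre_deal_index in use (illustrative; list and range indices become Tensors, while the value is forwarded to C++ unchanged):

import paddle

x = paddle.arange(12).reshape([3, 4])
# list and range indices are converted to Tensors by pre_deal_index
# before _getitem_dygraph / _setitem_dygraph run.
a = x[[0, 2]]      # list index  -> paddle.to_tensor([0, 2])
b = x[range(2)]    # range index -> paddle.to_tensor([0, 1])
# The value is no longer pre-converted here (the old
# pre_deal_index_and_value called paddle.to_tensor on it); it is parsed
# on the C++ side by dealWithValues.
x[0] = 3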