From cf3746fa1ffbc88d4f265a91c3ffc472e812e80f Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 2 Jan 2024 15:28:16 +0800 Subject: [PATCH] Set value with scalar (#60452) * set_value with scalar * fix ut --- paddle/fluid/pybind/eager_method.cc | 92 +++++++++------- paddle/fluid/pybind/slice_utils.h | 101 ++++++++++++++++++ .../base/dygraph/tensor_patch_methods.py | 11 +- 3 files changed, 157 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index c9b3b106597448..feaf7ccd1a2f68 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1613,12 +1613,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &use_strided_slice); // step2: Parse values - PADDLE_ENFORCE( - PyCheckTensor(value_obj), - platform::errors::InvalidArgument("The value must be a Tensor")); - + std::vector values; paddle::Tensor value_tensor = - reinterpret_cast(value_obj)->tensor; + dealWithValues(tensor, value_obj, &values, has_advanced_index); if (!has_advanced_index) { // use set_value OP if there is no advanced index @@ -1626,45 +1623,60 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, // Release gil and do tracing py::gil_scoped_release release; // use inplace set_value_ operator - if (value_tensor.initialized() && - (self->tensor.dtype() != value_tensor.dtype())) { - if (egr::Controller::Instance().GetAMPLevel() != - paddle::imperative::AmpLevel::O0) { - paddle::small_vector, - egr::kSlotSmallVectorSize> - tmps = {{self->tensor}, {value_tensor}}; - auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); - self->tensor = egr::EagerAmpAutoCast( - self->tensor.name(), self->tensor, amp_dtype, "set_value"); - value_tensor = egr::EagerAmpAutoCast( - value_tensor.name(), value_tensor, amp_dtype, "set_value"); - } + if (value_tensor.initialized()) { if (self->tensor.dtype() != value_tensor.dtype()) { - value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + paddle::small_vector, + egr::kSlotSmallVectorSize> + tmps = {{self->tensor}, {value_tensor}}; + auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); + self->tensor = egr::EagerAmpAutoCast( + self->tensor.name(), self->tensor, amp_dtype, "set_value"); + value_tensor = egr::EagerAmpAutoCast( + value_tensor.name(), value_tensor, amp_dtype, "set_value"); + } + if (self->tensor.dtype() != value_tensor.dtype()) { + value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + } } - } - // step3.1: Only basic indexing, use OP set_value. - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { - ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); - } - self->tensor = set_value_with_tensor__ad_func(self->tensor, - value_tensor, - slice_starts, - slice_ends, - slice_strides, - slice_axes, - decrease_axis, - none_axes); - if (PyCheckTensor(value_obj)) { - // pass the stop_gradient from value to tensor. - // pass stop gradient should be done after CheckInplace in - // set_value__dygraph_function. - if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && - egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { - egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + // step3.1: Only basic indexing, use OP set_value. + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); + } + self->tensor = set_value_with_tensor__ad_func(self->tensor, + value_tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes); + if (PyCheckTensor(value_obj)) { + // pass the stop_gradient from value to tensor. + // pass stop gradient should be done after CheckInplace in + // set_value__dygraph_function. + if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && + egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { + egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + } + } + } else { + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor); } + self->tensor = set_value__ad_func(self->tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes, + {1}, + values); } } else { // step3.2: Case for there are advanced indexing. diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index e60ab9406396a2..82bdcc80562c45 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -27,9 +27,11 @@ #include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -531,5 +533,104 @@ static void ParseBoolAndBroadcastIndices( } } +static paddle::Tensor dealWithValues(const paddle::Tensor& tensor, + PyObject* value_obj, + std::vector* values, + const bool trans_to_tensor) { + paddle::Tensor value_tensor; + if (PyCheckTensor(value_obj)) { + value_tensor = reinterpret_cast(value_obj)->tensor; + } else if (py::isinstance(value_obj)) { + paddle::Tensor value_tensor_tmp( + std::make_shared(), + egr::Controller::Instance().GenerateUniqueName()); + py::object value_obj_tmp(py::handle(value_obj), true); + py::object value = value_obj_tmp; + if (tensor.dtype() == phi::DataType::FLOAT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::BOOL) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "When assign a numpy.np value to a paddle.Tensor, " + "the data type of the paddle.Tensor must be bool, " + "float32, float64, complex64, complex128, int32 or int64, " + "please check the type of tensor.")); + } + SetTensorFromPyArray( + static_cast(value_tensor_tmp.impl().get()), + value, + tensor.place(), + false); + value_tensor = value_tensor_tmp; + } else { + py::object value_obj_tmp(py::handle(value_obj), true); + // convert the value to self data type + if (py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + PyComplex_Check(value_obj)) { + if (tensor.dtype() == phi::DataType::FLOAT32 || + tensor.dtype() == phi::DataType::FLOAT16 || + tensor.dtype() == phi::DataType::BFLOAT16) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT32 || + tensor.dtype() == phi::DataType::INT16 || + tensor.dtype() == phi::DataType::INT8 || + tensor.dtype() == phi::DataType::UINT8) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::BOOL) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + values->push_back(value_obj_tmp.cast>()); + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + values->push_back(value_obj_tmp.cast>()); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Value type error. The assign value allows " + "Tensor, numpy.ndarray, integer, float, complex or bool, " + "but received %s.", + Py_TYPE(value_obj))); + } + + if (trans_to_tensor) { + value_tensor = + full_ad_func({1}, (*values)[0], tensor.dtype(), tensor.place()); + } + } + return value_tensor; +} + } // namespace pybind } // namespace paddle diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index a6d1f90df4fa48..aed4833188d6c1 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -975,7 +975,7 @@ def __array__(self, dtype=None): array = array.astype(dtype) return array - def pre_deal_index_and_value(self, item, value=None): + def pre_deal_index(self, item): # since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor # we call this function in python level. item = list(item) if isinstance(item, tuple) else [item] @@ -985,17 +985,14 @@ def pre_deal_index_and_value(self, item, value=None): elif isinstance(slice_item, range): item[i] = paddle.to_tensor(list(slice_item)) - if value is not None and not isinstance(value, Variable): - value = paddle.to_tensor(value, dtype=self.dtype) - - return tuple(item), value + return tuple(item) def __getitem__(self, item): - item, _ = pre_deal_index_and_value(self, item) + item = pre_deal_index(self, item) return self._getitem_dygraph(item) def __setitem__(self, item, value): - item, value = pre_deal_index_and_value(self, item, value) + item = pre_deal_index(self, item) return self._setitem_dygraph(item, value) @framework.dygraph_only