Skip to content

Commit 8858439

Browse files
authored
[Phi] Add yaml for assign_value (#44596)
* [Phi] Add yaml for assign_value * [Phi] Fix the bug of the assign api and modify the unittest * [Phi] Fix the bug when the tensor does not have the backend info * [Phi] Replace the functional-style cast init by the brace-init * [Phi] Cast the data explicitly
1 parent 856f741 commit 8858439

File tree

13 files changed

+129
-15
lines changed

13 files changed

+129
-15
lines changed

paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
'Scalar(int64_t)' : 'paddle::experimental::Scalar',
5757
'Scalar(float)' : 'paddle::experimental::Scalar',
5858
'Scalar(double)' : 'paddle::experimental::Scalar',
59+
'Scalar[]' : 'std::vector<phi::Scalar>',
5960
'IntArray' : 'paddle::experimental::IntArray'
6061
}
6162

paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def SkipAPIGeneration(forward_api_name):
4545
"std::vector<double>": "CastPyArg2Float64s",
4646
"std::vector<std::string>": "CastPyArg2Strings",
4747
"paddle::experimental::Scalar": "CastPyArg2Scalar",
48+
"std::vector<phi::Scalar>": "CastPyArg2ScalarArray",
4849
"paddle::experimental::IntArray": "CastPyArg2IntArray",
4950
"paddle::Place": "CastPyArg2Place",
5051
"paddle::experimental::DataType": "CastPyArg2DataType",
@@ -87,6 +88,7 @@ def SkipAPIGeneration(forward_api_name):
8788
'rmsprop',
8889
'sgd_',
8990
'sgd',
91+
'assign_value_',
9092
'sparse_momentum_',
9193
'sparse_momentum',
9294
]

paddle/fluid/pybind/eager_utils.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,54 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
12531253
return paddle::experimental::Scalar(1.0);
12541254
}
12551255

1256+
std::vector<phi::Scalar> CastPyArg2ScalarArray(PyObject* obj,
1257+
const std::string& op_type,
1258+
ssize_t arg_pos) {
1259+
if (obj == Py_None) {
1260+
PADDLE_THROW(platform::errors::InvalidArgument(
1261+
"%s(): argument (position %d) must be "
1262+
"a list of int, float, or bool, but got %s",
1263+
op_type,
1264+
arg_pos + 1,
1265+
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
1266+
}
1267+
1268+
PyTypeObject* type = obj->ob_type;
1269+
auto type_name = std::string(type->tp_name);
1270+
VLOG(1) << "type_name: " << type_name;
1271+
if (PyList_Check(obj)) {
1272+
Py_ssize_t len = PyList_Size(obj);
1273+
PyObject* item = nullptr;
1274+
item = PyList_GetItem(obj, 0);
1275+
if (PyObject_CheckFloatOrToFloat(&item)) {
1276+
std::vector<phi::Scalar> value;
1277+
for (Py_ssize_t i = 0; i < len; i++) {
1278+
item = PyList_GetItem(obj, i);
1279+
value.emplace_back(phi::Scalar{PyFloat_AsDouble(item)});
1280+
}
1281+
return value;
1282+
} else if (PyObject_CheckLongOrToLong(&item)) {
1283+
std::vector<phi::Scalar> value;
1284+
for (Py_ssize_t i = 0; i < len; i++) {
1285+
item = PyList_GetItem(obj, i);
1286+
value.emplace_back(
1287+
phi::Scalar{static_cast<int64_t>(PyLong_AsLong(item))});
1288+
}
1289+
return value;
1290+
}
1291+
} else {
1292+
PADDLE_THROW(platform::errors::InvalidArgument(
1293+
"%s(): argument (position %d) must be "
1294+
"a list of int, float, or bool, but got %s",
1295+
op_type,
1296+
arg_pos + 1,
1297+
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
1298+
}
1299+
1300+
// Fake a ScalarArray
1301+
return std::vector<phi::Scalar>({phi::Scalar(1.0)});
1302+
}
1303+
12561304
paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
12571305
const std::string& op_type,
12581306
ssize_t arg_pos) {

paddle/fluid/pybind/eager_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
191191
const std::string& op_type,
192192
ssize_t arg_pos);
193193

194+
std::vector<phi::Scalar> CastPyArg2ScalarArray(PyObject* obj,
195+
const std::string& op_type,
196+
ssize_t arg_pos);
197+
194198
paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
195199
const std::string& op_type,
196200
ssize_t arg_pos);

paddle/phi/api/lib/kernel_dispatch.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ bool HasAllocation(const phi::TensorBase& t) {
5353
}
5454

5555
BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
56-
if (HasAllocation(t)) {
56+
if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) {
5757
BackendSet backend_set(phi::TransToPhiBackend(t.place()));
5858
switch (t.layout()) {
5959
case DataLayout::MKLDNN:

paddle/phi/api/yaml/generator/api_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
125125
'Scalar(int64_t)': 'const Scalar&',
126126
'Scalar(float)': 'const Scalar&',
127127
'Scalar(dobule)': 'const Scalar&',
128+
'Scalar[]': 'const std::vector<phi::Scalar>&',
128129
'int': 'int',
129130
'int32_t': 'int32_t',
130131
'int64_t': 'int64_t',
@@ -648,6 +649,10 @@ def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
648649
if 'IntArray' in self.attrs['attr_info'][param][0]:
649650
kernel_args_type_list.append('const phi::IntArray&')
650651
param = 'phi::IntArray(' + param + ')'
652+
elif 'vector<phi::Scalar>' in self.attrs['attr_info'][param][0]:
653+
kernel_args_type_list.append(
654+
'const std::vector<phi::Scalar>&')
655+
param = param
651656
elif 'Scalar' in self.attrs['attr_info'][param][0]:
652657
kernel_args_type_list.append('const phi::Scalar&')
653658
param = 'phi::Scalar(' + param + ')'

paddle/phi/api/yaml/generator/type_mapping.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
'Scalar(int)': 'const Scalar&',
3232
'Scalar(int64_t)': 'const Scalar&',
3333
'Scalar(float)': 'const Scalar&',
34+
'Scalar[]': 'const std::vector<Scalar>&',
3435
'Place': 'Place',
3536
'DataLayout': 'DataLayout',
3637
'DataType': 'DataType',
@@ -58,6 +59,7 @@
5859
'Scalar(int)': 'int',
5960
'Scalar(int64_t)': 'int64_t',
6061
'Scalar(float)': 'float',
62+
'Scalar[]': 'std::vector<Scalar>',
6163
'Place': 'int',
6264
'DataLayout': 'int',
6365
'DataType': 'int',
@@ -83,7 +85,8 @@
8385
phi_attr_types_map = attr_types_map.copy()
8486
phi_attr_types_map.update({
8587
'IntArray': 'const phi::IntArray&',
86-
'Scalar': 'const phi::Scalar&'
88+
'Scalar': 'const phi::Scalar&',
89+
'Scalar[]': 'std::vector<phi::Scalar>&'
8790
})
8891

8992
#--------------------------- phi dense tensor ---------------------------

paddle/phi/api/yaml/legacy_api.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,20 @@
246246
inplace : (output -> out)
247247
backward : assign_out__grad
248248

249+
# assgin_value
250+
- api : assign_value_
251+
args : (Tensor output, int[] shape, DataType dtype, Scalar[] values, Place place = {})
252+
output : Tensor(out)
253+
inplace: (output -> out)
254+
infer_meta :
255+
func : AssignValueInferMeta
256+
param : [shape, dtype]
257+
kernel :
258+
func : assign_value
259+
param : [shape, dtype, values]
260+
data_type : dtype
261+
backend : place > output
262+
249263
# atan
250264
- api : atan
251265
args : (Tensor x)

paddle/phi/core/infermeta_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
209209
std::vector<double>);
210210
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
211211
std::vector<std::string>);
212+
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
213+
std::vector<Scalar>);
212214

213215
template <typename... Tail>
214216
struct InferMetaFnCallHelper<MetaTensor*, Tail...> {

python/paddle/fluid/layers/tensor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,12 +690,20 @@ def assign(input, output=None):
690690
if input.size > 1024 * 1024:
691691
raise ValueError("The size of input is too big. Please consider "
692692
"saving it to file and 'load_op' to load it")
693-
if output is None:
694-
output = helper.create_variable_for_type_inference(dtype=dtype)
695-
if _non_static_mode():
693+
if in_dygraph_mode():
694+
if output is None:
695+
output = zeros(list(input.shape), dtype)
696+
_C_ops.final_state_assign_value_(output, list(input.shape), dtype,
697+
values, _current_expected_place())
698+
elif _in_legacy_dygraph():
699+
if output is None:
700+
output = core.VarBase()
696701
_C_ops.assign_value(output, 'shape', list(input.shape), 'dtype',
697702
dtype, value_name, values)
698703
else:
704+
if output is None:
705+
output = helper.create_variable_for_type_inference(
706+
dtype=input.dtype)
699707
helper.append_op(type='assign_value',
700708
outputs={'Out': [output]},
701709
attrs={

0 commit comments

Comments
 (0)