Skip to content

Commit 176df91

Browse files
authored
Add some op yaml (#41173)
* add real and imag yaml * add roi_align and roi_pool yaml * add qr yaml * add psroi_pool yaml * fix bug * fix param bug of psroi_pool * fix infrt problem * fix merge bug
1 parent 7ed7c6c commit 176df91

File tree

13 files changed

+185
-18
lines changed

13 files changed

+185
-18
lines changed

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place)
165165
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool)
166166
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
167167
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
168-
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform backward_infermeta)
168+
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform)
169169
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
170170

171171
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include "paddle/phi/api/lib/data_transform.h"
1919
#include "paddle/phi/api/lib/kernel_dispatch.h"
2020
#include "paddle/phi/api/lib/utils/storage.h"
21+
#include "paddle/phi/common/type_traits.h"
2122
#include "paddle/phi/core/compat/convert_utils.h"
2223
#include "paddle/phi/core/kernel_registry.h"
2324
#include "paddle/phi/core/meta_tensor.h"
@@ -716,6 +717,62 @@ std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
716717
return x_grad;
717718
}
718719

720+
Tensor imag_grad_impl(const Tensor& out_grad) {
721+
phi::KernelKey kernel_key{ParseBackend(out_grad),
722+
out_grad.layout(),
723+
phi::dtype::ToComplex(out_grad.dtype())};
724+
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
725+
"imag_grad", kernel_key);
726+
727+
VLOG(6) << "imag_grad API kernel key: " << kernel_key;
728+
VLOG(6) << "imag_grad API kernel: " << kernel;
729+
730+
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
731+
732+
auto dense_out_grad = TensorToDenseTensor(out_grad);
733+
734+
Tensor out;
735+
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
736+
phi::MetaTensor meta_out(kernel_out);
737+
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
738+
739+
using kernel_signature = void (*)(
740+
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
741+
742+
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
743+
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);
744+
745+
return out;
746+
}
747+
748+
Tensor real_grad_impl(const Tensor& out_grad) {
749+
phi::KernelKey kernel_key{ParseBackend(out_grad),
750+
out_grad.layout(),
751+
phi::dtype::ToComplex(out_grad.dtype())};
752+
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
753+
"real_grad", kernel_key);
754+
755+
VLOG(6) << "real_grad API kernel key: " << kernel_key;
756+
VLOG(6) << "real_grad API kernel: " << kernel;
757+
758+
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
759+
760+
auto dense_out_grad = TensorToDenseTensor(out_grad);
761+
762+
Tensor out;
763+
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
764+
phi::MetaTensor meta_out(kernel_out);
765+
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
766+
767+
using kernel_signature = void (*)(
768+
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
769+
770+
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
771+
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);
772+
773+
return out;
774+
}
775+
719776
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
720777
const Tensor& out_grad,
721778
int axis) {

paddle/phi/api/lib/api_custom_impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
9292
bool trainable_statistics,
9393
bool fuse_with_relu);
9494

95+
/************************ backward api impl ***************************/
96+
9597
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
9698
const Tensor& out_grad,
9799
const Scalar& axis);
98100

101+
Tensor imag_grad_impl(const Tensor& x);
102+
103+
Tensor real_grad_impl(const Tensor& x);
104+
99105
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
100106
const Tensor& out_grad,
101107
int axis);

paddle/phi/infermeta/backward.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/phi/infermeta/backward.h"
1616

17+
#include "paddle/phi/common/type_traits.h"
1718
#include "paddle/phi/kernels/funcs/axis_utils.h"
1819

1920
namespace phi {
@@ -402,6 +403,12 @@ void PsroiPoolGradInferMeta(const MetaTensor& x,
402403
dx->share_meta(x);
403404
}
404405

406+
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
407+
dx->set_dims(out_grad.dims());
408+
dx->set_dtype(dtype::ToComplex(out_grad.dtype()));
409+
dx->set_layout(out_grad.layout());
410+
}
411+
405412
void ScatterGradInferMeta(const MetaTensor& index,
406413
const MetaTensor& updates,
407414
const MetaTensor& out_grad,

paddle/phi/infermeta/backward.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ void PoolGradInferMeta(const MetaTensor& x,
174174
const std::string& padding_algorithm,
175175
MetaTensor* dx);
176176

177+
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx);
178+
177179
void ScatterGradInferMeta(const MetaTensor& index,
178180
const MetaTensor& updates,
179181
const MetaTensor& out_grad,

python/paddle/fluid/tests/unittests/test_psroi_pool_op.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def set_data(self):
9595
self.pooled_width).astype('float64')
9696
self.inputs = {
9797
'X': self.x,
98-
'ROIs': (self.rois_with_batch_id[:, 1:5], self.rois_lod)
98+
'ROIs': (self.rois_with_batch_id[:, 1:5], self.rois_lod),
99+
'RoisNum': self.boxes_num
99100
}
100101
self.attrs = {
101102
'output_channels': self.output_channels,
@@ -145,13 +146,14 @@ def make_rois(self):
145146

146147
def setUp(self):
147148
self.op_type = 'psroi_pool'
149+
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, output_channels, spatial_scale: paddle.vision.ops.psroi_pool(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale)
148150
self.set_data()
149151

150152
def test_check_output(self):
151-
self.check_output()
153+
self.check_output(check_eager=True)
152154

153155
def test_check_grad(self):
154-
self.check_grad(['X'], 'Out')
156+
self.check_grad(['X'], 'Out', check_eager=True)
155157

156158

157159
class TestPSROIPoolDynamicFunctionAPI(unittest.TestCase):

python/paddle/fluid/tests/unittests/test_real_imag_op.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def setUp(self):
3939
paddle.enable_static()
4040
# op test attrs
4141
self.op_type = "real"
42+
self.python_api = paddle.real
4243
self.dtype = np.float64
4344
self.init_input_output()
4445
# backward attrs
@@ -58,14 +59,15 @@ def init_grad_input_output(self):
5859
self.grad_out.shape)
5960

6061
def test_check_output(self):
61-
self.check_output()
62+
self.check_output(check_eager=True)
6263

6364
def test_check_grad(self):
6465
self.check_grad(
6566
['X'],
6667
'Out',
6768
user_defined_grads=[self.grad_x],
68-
user_defined_grad_outputs=[self.grad_out])
69+
user_defined_grad_outputs=[self.grad_out],
70+
check_eager=True)
6971

7072

7173
class TestImagOp(TestRealOp):
@@ -74,6 +76,7 @@ def setUp(self):
7476
paddle.enable_static()
7577
# op test attrs
7678
self.op_type = "imag"
79+
self.python_api = paddle.imag
7780
self.dtype = np.float64
7881
self.init_input_output()
7982
# backward attrs

python/paddle/fluid/tests/unittests/test_roi_pool_op.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import print_function
1616

17+
import paddle
1718
import unittest
1819
import numpy as np
1920
import math
@@ -32,6 +33,7 @@ def set_data(self):
3233
self.inputs = {
3334
'X': self.x,
3435
'ROIs': (self.rois[:, 1:5], self.rois_lod),
36+
'RoisNum': self.boxes_num
3537
}
3638

3739
self.attrs = {
@@ -130,16 +132,20 @@ def make_rois(self):
130132
rois.append(roi)
131133
self.rois_num = len(rois)
132134
self.rois = np.array(rois).astype("float64")
135+
self.boxes_num = np.array(
136+
[bno + 1 for bno in range(self.batch_size)]).astype('int32')
133137

134138
def setUp(self):
135139
self.op_type = "roi_pool"
140+
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, spatial_scale: paddle.vision.ops.roi_pool(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale)
141+
self.python_out_sig = ["Out"]
136142
self.set_data()
137143

138144
def test_check_output(self):
139-
self.check_output()
145+
self.check_output(check_eager=True)
140146

141147
def test_check_grad(self):
142-
self.check_grad(['X'], 'Out')
148+
self.check_grad(['X'], 'Out', check_eager=True)
143149

144150

145151
class BadInputTestRoiPool(unittest.TestCase):

python/paddle/tensor/attribute.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from ..fluid.layer_helper import LayerHelper
1919
from ..fluid.data_feeder import check_variable_and_dtype
2020

21-
# TODO: define functions to get tensor attributes
21+
# TODO: define functions to get tensor attributes
2222
from ..fluid.layers import rank # noqa: F401
2323
from ..fluid.layers import shape # noqa: F401
2424
import paddle
2525
from paddle import _C_ops
2626
from paddle.static import Variable
27+
from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode
2728

2829
__all__ = []
2930

@@ -185,7 +186,9 @@ def real(x, name=None):
185186
# [[1., 2., 3.],
186187
# [4., 5., 6.]])
187188
"""
188-
if paddle.in_dynamic_mode():
189+
if in_dygraph_mode():
190+
return _C_ops.final_state_real(x)
191+
if _in_legacy_dygraph():
189192
return _C_ops.real(x)
190193

191194
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'real')
@@ -229,7 +232,9 @@ def imag(x, name=None):
229232
# [[6., 5., 4.],
230233
# [3., 2., 1.]])
231234
"""
232-
if paddle.in_dynamic_mode():
235+
if in_dygraph_mode():
236+
return _C_ops.final_state_imag(x)
237+
if _in_legacy_dygraph():
233238
return _C_ops.imag(x)
234239

235240
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'imag')

python/paddle/utils/code_gen/api.yaml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,15 @@
802802
func : huber_loss
803803
# backward : huber_loss_grad
804804

805+
- api : imag
806+
args : (Tensor x)
807+
output : Tensor
808+
infer_meta :
809+
func : RealAndImagInferMeta
810+
kernel :
811+
func : imag
812+
backward : imag_grad
813+
805814
# increment
806815
- api : increment
807816
args : (Tensor x, float value)
@@ -1336,6 +1345,16 @@
13361345
func : prelu
13371346
backward : prelu_grad
13381347

1348+
- api : psroi_pool
1349+
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale)
1350+
output : Tensor
1351+
infer_meta :
1352+
func : PsroiPoolInferMeta
1353+
kernel :
1354+
func : psroi_pool
1355+
optional : boxes_num
1356+
backward : psroi_pool_grad
1357+
13391358
# put_along_axis
13401359
- api : put_along_axis
13411360
args : (Tensor x, Tensor index, Tensor value, int axis, str reduce)
@@ -1348,6 +1367,15 @@
13481367
data_type : x
13491368
backward : put_along_axis_grad
13501369

1370+
- api : qr
1371+
args : (Tensor x, str mode)
1372+
output : Tensor(q), Tensor(r)
1373+
infer_meta :
1374+
func : QrInferMeta
1375+
kernel :
1376+
func : qr
1377+
# backward : qr_grad
1378+
13511379
- api : randint
13521380
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
13531381
output : Tensor(out)
@@ -1372,6 +1400,15 @@
13721400
data_type : dtype
13731401
backend : place
13741402

1403+
- api : real
1404+
args : (Tensor x)
1405+
output : Tensor
1406+
infer_meta :
1407+
func : RealAndImagInferMeta
1408+
kernel :
1409+
func : real
1410+
backward : real_grad
1411+
13751412
- api : reciprocal
13761413
args : (Tensor x)
13771414
output : Tensor
@@ -1423,6 +1460,17 @@
14231460
optional : boxes_num
14241461
backward : roi_align_grad
14251462

1463+
- api : roi_pool
1464+
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale)
1465+
output : Tensor(out), Tensor(arg_max)
1466+
infer_meta :
1467+
func : RoiPoolInferMeta
1468+
kernel :
1469+
func : roi_pool
1470+
optional : boxes_num
1471+
intermediate : arg_max
1472+
backward : roi_pool_grad
1473+
14261474
- api : roll
14271475
args : (Tensor x, IntArray shifts, int64_t[] axis)
14281476
output : Tensor(out)

0 commit comments

Comments
 (0)