Skip to content

Commit e397b29

Browse files
authored
【PIR API adaptor No.90,92】Migrate some ops into pir (#59801)
1 parent d307890 commit e397b29

File tree

4 files changed

+27
-5
lines changed

4 files changed

+27
-5
lines changed

python/paddle/incubate/operators/graph_reindex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from paddle import _C_ops
1616
from paddle.base.data_feeder import check_variable_and_dtype
1717
from paddle.base.layer_helper import LayerHelper
18-
from paddle.framework import in_dynamic_mode
18+
from paddle.framework import in_dynamic_or_pir_mode
1919
from paddle.utils import deprecated
2020

2121

@@ -130,7 +130,7 @@ def graph_reindex(
130130
"be None if `flag_buffer_hashtable` is True."
131131
)
132132

133-
if in_dynamic_mode():
133+
if in_dynamic_or_pir_mode():
134134
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
135135
x,
136136
neighbors,

python/paddle/vision/ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
convert_np_dtype_to_dtype_,
2727
in_dygraph_mode,
2828
in_dynamic_or_pir_mode,
29+
in_pir_mode,
2930
)
3031
from ..base.layer_helper import LayerHelper
3132
from ..framework import _current_expected_place
@@ -2144,6 +2145,27 @@ def generate_proposals(
21442145
scores, bbox_deltas, img_size, anchors, variances, *attrs
21452146
)
21462147

2148+
return rpn_rois, rpn_roi_probs, rpn_rois_num
2149+
elif in_pir_mode():
2150+
assert (
2151+
return_rois_num
2152+
), "return_rois_num should be True in PaddlePaddle inner op mode."
2153+
rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals(
2154+
scores,
2155+
bbox_deltas,
2156+
img_size,
2157+
anchors,
2158+
variances,
2159+
pre_nms_top_n,
2160+
post_nms_top_n,
2161+
nms_thresh,
2162+
min_size,
2163+
eta,
2164+
pixel_offset,
2165+
)
2166+
rpn_rois.stop_gradient = True
2167+
rpn_roi_probs.stop_gradient = True
2168+
rpn_rois_num.stop_gradient = True
21472169
return rpn_rois, rpn_roi_probs, rpn_rois_num
21482170
else:
21492171
helper = LayerHelper('generate_proposals_v2', **locals())

test/legacy_test/test_generate_proposals_v2_op.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def python_generate_proposals_v2(
5454
pixel_offset=pixel_offset,
5555
return_rois_num=return_rois_num,
5656
)
57-
return rpn_rois, rpn_roi_probs
57+
return rpn_rois, rpn_roi_probs, rpn_rois_num
5858

5959

6060
def generate_proposals_v2_in_python(
@@ -223,12 +223,11 @@ def set_data(self):
223223
}
224224

225225
def test_check_output(self):
226-
self.check_output()
226+
self.check_output(check_pir=True)
227227

228228
def setUp(self):
229229
self.op_type = "generate_proposals_v2"
230230
self.python_api = python_generate_proposals_v2
231-
self.python_out_sig = ['Out']
232231
self.set_data()
233232

234233
def init_test_params(self):

test/legacy_test/test_graph_reindex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_heter_reindex_result_v2(self):
129129
np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05)
130130
np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05)
131131

132+
@test_with_pir_api
132133
def test_reindex_result_static(self):
133134
paddle.enable_static()
134135
with paddle.static.program_guard(paddle.static.Program()):

0 commit comments

Comments
 (0)