Skip to content

Commit

Permalink
【PIR API adaptor No.90,92】Migrate some ops into pir (#59801)
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 authored Jan 4, 2024
1 parent d307890 commit e397b29
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/paddle/incubate/operators/graph_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode
from paddle.utils import deprecated


Expand Down Expand Up @@ -130,7 +130,7 @@ def graph_reindex(
"be None if `flag_buffer_hashtable` is True."
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
x,
neighbors,
Expand Down
22 changes: 22 additions & 0 deletions python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
convert_np_dtype_to_dtype_,
in_dygraph_mode,
in_dynamic_or_pir_mode,
in_pir_mode,
)
from ..base.layer_helper import LayerHelper
from ..framework import _current_expected_place
Expand Down Expand Up @@ -2144,6 +2145,27 @@ def generate_proposals(
scores, bbox_deltas, img_size, anchors, variances, *attrs
)

return rpn_rois, rpn_roi_probs, rpn_rois_num
elif in_pir_mode():
assert (
return_rois_num
), "return_rois_num should be True in PaddlePaddle inner op mode."
rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals(
scores,
bbox_deltas,
img_size,
anchors,
variances,
pre_nms_top_n,
post_nms_top_n,
nms_thresh,
min_size,
eta,
pixel_offset,
)
rpn_rois.stop_gradient = True
rpn_roi_probs.stop_gradient = True
rpn_rois_num.stop_gradient = True
return rpn_rois, rpn_roi_probs, rpn_rois_num
else:
helper = LayerHelper('generate_proposals_v2', **locals())
Expand Down
5 changes: 2 additions & 3 deletions test/legacy_test/test_generate_proposals_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def python_generate_proposals_v2(
pixel_offset=pixel_offset,
return_rois_num=return_rois_num,
)
return rpn_rois, rpn_roi_probs
return rpn_rois, rpn_roi_probs, rpn_rois_num


def generate_proposals_v2_in_python(
Expand Down Expand Up @@ -223,12 +223,11 @@ def set_data(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def setUp(self):
self.op_type = "generate_proposals_v2"
self.python_api = python_generate_proposals_v2
self.python_out_sig = ['Out']
self.set_data()

def init_test_params(self):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_graph_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_heter_reindex_result_v2(self):
np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05)
np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05)

@test_with_pir_api
def test_reindex_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
Expand Down

0 comments on commit e397b29

Please sign in to comment.