Skip to content

Commit 1f4d4cd

Browse files
authored
[PIR] No.46 Migrate paddle.nn.functional.pad into pir (#57348)
1 parent 146489a commit 1f4d4cd

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
'stack',
6262
'poisson',
6363
'gumbel_softmax',
64+
'pad',
65+
'pad3d',
6466
'squeeze',
6567
'unsqueeze',
6668
'tril',
@@ -104,6 +106,8 @@
104106
'stack',
105107
'poisson',
106108
'gumbel_softmax',
109+
'pad',
110+
'pad3d',
107111
'squeeze',
108112
'unsqueeze',
109113
'tril',

python/paddle/nn/functional/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,7 +1658,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
16581658
paddings = pad
16591659
pad_value = value
16601660

1661-
if in_dynamic_mode():
1661+
if in_dynamic_or_pir_mode():
16621662
out = _C_ops.pad(x, paddings, float(pad_value))
16631663
return out
16641664

@@ -1712,7 +1712,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
17121712

17131713
unsqueezed_dim = []
17141714

1715-
if isinstance(pad, Variable):
1715+
if isinstance(pad, (Variable, pir.OpResult)):
17161716
if data_format in ["NCL", "NCHW", "NCDHW"]:
17171717
data_format = "NCDHW"
17181718
if x_dim == 3:
@@ -1756,7 +1756,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
17561756
unsqueezed_dim = [1]
17571757
x = unsqueeze(x, axis=unsqueezed_dim)
17581758

1759-
if in_dynamic_mode():
1759+
if in_dynamic_or_pir_mode():
17601760
if isinstance(pad, Variable):
17611761
pad = pad.tolist()
17621762
out = _C_ops.pad3d(x, pad, mode, value, data_format)

test/legacy_test/test_pad3d_op.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ def setUp(self):
9191
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
9292

9393
def test_check_output(self):
94-
self.check_output()
94+
self.check_output(check_new_ir=True)
9595

9696
def test_check_grad_normal(self):
97-
self.check_grad(['X'], 'Out')
97+
self.check_grad(['X'], 'Out', check_new_ir=True)
9898

9999
def get_dtype(self):
100100
return np.float64
@@ -214,10 +214,12 @@ def get_dtype(self):
214214
return np.float16
215215

216216
def test_check_output(self):
217-
self.check_output(atol=1e-3)
217+
self.check_output(atol=1e-3, check_new_ir=True)
218218

219219
def test_check_grad_normal(self):
220-
self.check_grad(['X'], 'Out', max_relative_error=1.5e-3)
220+
self.check_grad(
221+
['X'], 'Out', max_relative_error=1.5e-3, check_new_ir=True
222+
)
221223

222224
cls_name = "{}_{}".format(parent.__name__, "FP16OP")
223225
TestPad3dFp16.__name__ = cls_name
@@ -251,12 +253,12 @@ def get_dtype(self):
251253

252254
def test_check_output(self):
253255
place = core.CUDAPlace(0)
254-
self.check_output_with_place(place, atol=1e-2)
256+
self.check_output_with_place(place, atol=1e-2, check_new_ir=True)
255257

256258
def test_check_grad_normal(self):
257259
place = core.CUDAPlace(0)
258260
self.check_grad_with_place(
259-
place, ['X'], 'Out', max_relative_error=1e-2
261+
place, ['X'], 'Out', max_relative_error=1e-2, check_new_ir=True
260262
)
261263

262264
cls_name = "{}_{}".format(parent.__name__, "BF16OP")

test/legacy_test/test_pad_op.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def get_dtype(self):
5757
return np.float64
5858

5959
def test_check_output(self):
60-
self.check_output()
60+
self.check_output(check_new_ir=True)
6161

6262
def test_check_grad_normal(self):
63-
self.check_grad(['X'], 'Out', check_prim=True)
63+
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
6464

6565
def initTestCase(self):
6666
self.shape = (16, 16)
@@ -101,7 +101,7 @@ def get_dtype(self):
101101
return np.float16
102102

103103
def test_check_grad_normal(self):
104-
self.check_grad(['X'], 'Out', check_prim=True)
104+
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
105105

106106
cls_name = "{}_{}".format(parent.__name__, "Fp16")
107107
TestPadFp16.__name__ = cls_name
@@ -253,11 +253,13 @@ def initTestCase(self):
253253

254254
def test_check_output(self):
255255
place = core.CUDAPlace(0)
256-
self.check_output_with_place(place)
256+
self.check_output_with_place(place, check_new_ir=True)
257257

258258
def test_check_grad(self):
259259
place = core.CUDAPlace(0)
260-
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
260+
self.check_grad_with_place(
261+
place, ['X'], 'Out', check_prim=True, check_new_ir=True
262+
)
261263

262264

263265
if __name__ == '__main__':

0 commit comments

Comments
 (0)