[PIR] support matrix_norm and fix backward redundant cast (PaddlePa…
gouzil authored Apr 1, 2024
1 parent 31174be commit aed2d92
Showing 3 changed files with 58 additions and 18 deletions.
4 changes: 2 additions & 2 deletions python/paddle/tensor/linalg.py
@@ -661,7 +661,7 @@ def nuclear_norm(input, axis=axis, keepdim=False, name=None):
     perm = _backshift_permutation(axis[0], axis[1], len(input.shape))
     inv_perm = _inverse_permutation(perm)
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         transposed = _C_ops.transpose(input, perm)
         u, s, vh = _C_ops.svd(transposed, False)
         result = _C_ops.sum(s, -1, None, keepdim)
@@ -754,7 +754,7 @@ def p_matrix_norm(input, porder=1.0, axis=axis, keepdim=False, name=None):
     perm = _backshift_permutation(axis[0], axis[1], len(input.shape))
     inv_perm = _inverse_permutation(perm)
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         abs_ord = abs(porder)
 
         max_min = _C_ops.max if porder > 0.0 else _C_ops.min
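
The only change in this file is the mode predicate. A minimal sketch of why that is enough (not the library source; the import path of in_dynamic_or_pir_mode and the simplified 2-D-only body are assumptions):

# Hedged sketch, not the PaddlePaddle source: in_dynamic_or_pir_mode() is True
# in dygraph *and* while a PIR static program is being built, so the same
# _C_ops calls either execute eagerly or are recorded as PIR ops.
from paddle import _C_ops
from paddle.base.framework import in_dynamic_or_pir_mode  # assumed import path


def nuclear_norm_sketch(x, keepdim=False):
    """Nuclear norm of a 2-D tensor, mirroring the branch touched above."""
    if in_dynamic_or_pir_mode():
        u, s, vh = _C_ops.svd(x, False)          # singular values of x
        return _C_ops.sum(s, -1, None, keepdim)  # sum of singular values
    raise NotImplementedError("legacy static-graph branch omitted in this sketch")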
2 changes: 2 additions & 0 deletions test/legacy_test/test_zero_dim_sundry_static_api_part1.py
@@ -142,6 +142,7 @@ def test_create_parameter_var(self):
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[0], 0.5)
 
+    @test_with_pir_api
     @prog_scope()
     def test_getitem(self):
         # case1: When all axis have a scalar indice, output should be a 0-d Tensor;
@@ -764,6 +765,7 @@ def test_inner(self):
         self.assertEqual(res[2].shape, (2, 2))
         self.assertEqual(res[3].shape, (2, 2))
 
+    @test_with_pir_api
     @prog_scope()
     def test_tensordot(self):
         x = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
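
This file only gains the extra decorator on two tests. A standalone sketch of the pattern (the paddle.pir_utils import path and the rerun-under-PIR behaviour described in the comment are assumptions; the prog_scope/self.exe harness used above is replaced by an explicit program_guard):

# Hedged sketch, not part of the commit: @test_with_pir_api reruns the
# decorated body a second time with PIR enabled, so one static-graph test
# covers both the legacy IR and PIR.
import unittest

import paddle
from paddle.pir_utils import test_with_pir_api  # assumed import path

paddle.enable_static()


class ZeroDimSketch(unittest.TestCase):
    @test_with_pir_api
    def test_tensordot(self):
        main, startup = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
            out = paddle.tensordot(x, x, axes=1)  # contracts to a 0-D tensor
        exe = paddle.static.Executor()
        exe.run(startup)
        (res,) = exe.run(main, fetch_list=[out])
        self.assertEqual(res.shape, ())


if __name__ == "__main__":
    unittest.main()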
70 changes: 54 additions & 16 deletions test/legacy_test/test_zero_dim_sundry_static_api_part4.py
@@ -272,6 +272,7 @@ def test_det(self):
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (3, 3))
 
+    @test_with_pir_api
     @prog_scope()
     def test_dist(self):
         x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
@@ -288,11 +289,12 @@ def test_dist(self):
 
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (2, 2))
-        self.assertEqual(res[1].shape, (2, 2))
+        self.assertEqual(res[2].shape, (2, 2))
         np.testing.assert_array_equal(res[0], np.array(2).astype(np.float32))
 
+    @test_with_pir_api
     @prog_scope()
-    def test_linalg_norm(self):
+    def test_linalg_norm1(self):
         # 1D input, p = fro ,axis = None, using reduceInferMeta
         x_1 = paddle.arange(24, dtype="float32") - 12
         x_1.stop_gradient = False
@@ -306,85 +308,120 @@
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (24,))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm2(self):
         # 1D input, p = 1 ,axis = None,
         # using p_norm, as_vector = True
         x_2 = paddle.arange(24, dtype="float32") - 12
         x_2.stop_gradient = False
         out_2 = paddle.linalg.norm(x_2, p=1)
-        paddle.static.append_backward(out_2.sum())
+        ((_, x_2_grad),) = paddle.static.append_backward(
+            out_2.sum(), parameter_list=[x_2]
+        )
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_2, x_2.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_2, x_2_grad])
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (24,))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm3(self):
         # 1D input, p = 1 ,axis = 0,
         # using p_norm, as_vector = False
         x_2_p = paddle.arange(24, dtype="float32") - 12
         x_2_p.stop_gradient = False
         out_2_p = paddle.linalg.norm(x_2_p, p=1, axis=0)
-        paddle.static.append_backward(out_2_p.sum())
+        ((_, x_2_p_grad),) = paddle.static.append_backward(
+            out_2_p.sum(), parameter_list=[x_2_p]
+        )
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_2_p, x_2_p.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_2_p, x_2_p_grad])
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (24,))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm4(self):
         # 1D input, p = fro ,axis = 0,
         # using p_norm, as_vector = False
         x_2_fro = paddle.arange(24, dtype="float32") - 12
         x_2_fro.stop_gradient = False
         out_2_fro = paddle.linalg.norm(x_2_fro, p="fro", axis=0)
-        paddle.static.append_backward(out_2_fro.sum())
+        ((_, x_2_fro_grad),) = paddle.static.append_backward(
+            out_2_fro.sum(), parameter_list=[x_2_fro]
+        )
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_2_fro, x_2_fro.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_2_fro, x_2_fro_grad])
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (24,))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm5(self):
         # 2D input, p = 1, axis = [0, 1]
         # using p_matrix_norm, depends on paddle.sum
         x_3 = paddle.arange(24, dtype="float32").reshape([4, 6])
         x_3.stop_gradient = False
         out_3 = paddle.linalg.norm(x_3, p=1, axis=[0, 1])
-        paddle.static.append_backward(out_3.sum())
+        ((_, x_3_grad),) = paddle.static.append_backward(
+            out_3.sum(), parameter_list=[x_3]
+        )
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_3, x_3.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_3, x_3_grad])
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (4, 6))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm6(self):
         # 2D input, p = 1, axis = None
         # using p_matrix_norm, depends on paddle.sum
         x_4 = paddle.arange(24, dtype="float32").reshape([4, 6])
         x_4.stop_gradient = False
         out_4 = paddle.linalg.norm(x_4)
-        paddle.static.append_backward(out_4.sum())
+        ((_, x_4_grad),) = paddle.static.append_backward(
+            out_4.sum(), parameter_list=[x_4]
+        )
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_4, x_4.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_4, x_4_grad])
 
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (4, 6))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm7(self):
         # 2D input, p = inf, axis = None
         x_5 = paddle.arange(24, dtype="float32").reshape([4, 6])
         x_5.stop_gradient = False
         out_5 = paddle.linalg.norm(x_5)
-        paddle.static.append_backward(out_5.sum())
+        ((_, x_5_grad),) = paddle.static.append_backward(
+            out_5.sum(), parameter_list=[x_5]
+        )
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_5, x_5.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_5, x_5_grad])
 
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (4, 6))
 
+    @test_with_pir_api
+    @prog_scope()
+    def test_linalg_norm8(self):
         # 2D input, p = -inf, axis = [0, 1]
         x_6 = paddle.arange(24, dtype="float32").reshape([4, 6])
         x_6.stop_gradient = False
         out_6 = paddle.linalg.norm(x_6, p=-float("inf"), axis=[0, 1])
-        paddle.static.append_backward(out_6.sum())
+        ((_, x_6_grad),) = paddle.static.append_backward(
+            out_6.sum(), parameter_list=[x_6]
+        )
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out_6, x_6.grad_name])
+        res = self.exe.run(prog, fetch_list=[out_6, x_6_grad])
 
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (4, 6))
@@ -499,6 +536,7 @@ def test_linalg_cond(self):
         self.assertEqual(res[0].shape, (2,))
         self.assertEqual(res[1].shape, (2, 4, 4))
 
+    @test_with_pir_api
     @prog_scope()
     def test_trace(self):
        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
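
The recurring test change above replaces fetching gradients by grad_name, which PIR programs do not expose, with fetching the gradient variable returned by append_backward. A standalone sketch of that pattern outside the test harness (shapes follow the p=1, axis=[0, 1] case above; not part of the commit):

# Hedged sketch, not part of the commit: append_backward(..., parameter_list=[x])
# returns (param, grad) pairs, and the grad is fetched by handle instead of by
# x.grad_name.
import paddle

paddle.enable_static()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.arange(24, dtype="float32").reshape([4, 6])
    x.stop_gradient = False
    out = paddle.linalg.norm(x, p=1, axis=[0, 1])  # matrix norm, 0-D output
    ((_, x_grad),) = paddle.static.append_backward(
        out.sum(), parameter_list=[x]
    )

exe = paddle.static.Executor()
exe.run(startup)
out_v, grad_v = exe.run(main, fetch_list=[out, x_grad])
print(out_v.shape, grad_v.shape)  # expected: () and (4, 6)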
