Skip to content

Commit 209bdda

Browse files
committed
add diagonal_scatter_test
1 parent 99ac6ee commit 209bdda

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed
Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from op_test import convert_float_to_uint16
19+
20+
import paddle
21+
from paddle import base
22+
from paddle.base import core
23+
24+
25+
def fill_diagonal_ndarray(x, value, offset=0, dim1=0, dim2=1):
26+
"""Fill value into the diagonal of x that offset is ${offset} and the coordinate system is (dim1, dim2)."""
27+
strides = x.strides
28+
shape = x.shape
29+
if dim1 > dim2:
30+
dim1, dim2 = dim2, dim1
31+
assert 0 <= dim1 < dim2 <= 2
32+
assert len(x.shape) == 3
33+
34+
dim_sum = dim1 + dim2
35+
dim3 = len(x.shape) - dim_sum
36+
if offset >= 0:
37+
diagdim = min(shape[dim1], shape[dim2] - offset)
38+
diagonal = np.lib.stride_tricks.as_strided(
39+
x[:, offset:] if dim_sum == 1 else x[:, :, offset:],
40+
shape=(shape[dim3], diagdim),
41+
strides=(strides[dim3], strides[dim1] + strides[dim2]),
42+
)
43+
else:
44+
diagdim = min(shape[dim2], shape[dim1] + offset)
45+
diagonal = np.lib.stride_tricks.as_strided(
46+
x[-offset:, :] if dim_sum in [1, 2] else x[:, -offset:],
47+
shape=(shape[dim3], diagdim),
48+
strides=(strides[dim3], strides[dim1] + strides[dim2]),
49+
)
50+
51+
diagonal[...] = value
52+
return x
53+
54+
55+
def fill_gt(x, y, offset, dim1, dim2):
56+
if dim1 > dim2:
57+
dim1, dim2 = dim2, dim1
58+
offset = -offset
59+
xshape = x.shape
60+
yshape = y.shape
61+
62+
perm_list = []
63+
unperm_list = [0] * len(xshape)
64+
idx = 0
65+
66+
for i in range(len(xshape)):
67+
if i != dim1 and i != dim2:
68+
perm_list.append(i)
69+
unperm_list[i] = idx
70+
idx += 1
71+
perm_list += [dim1, dim2]
72+
unperm_list[dim1] = idx
73+
unperm_list[dim2] = idx + 1
74+
75+
x = np.transpose(x, perm_list)
76+
y = y.reshape((-1, yshape[-1]))
77+
nxshape = x.shape
78+
x = x.reshape((-1, xshape[dim1], xshape[dim2]))
79+
out = fill_diagonal_ndarray(x, y, offset, 1, 2)
80+
81+
out = out.reshape(nxshape)
82+
out = np.transpose(out, unperm_list)
83+
return out
84+
85+
86+
class TestDiagonalScatterAPI(unittest.TestCase):
87+
def set_args(self):
88+
self.dtype = "float32"
89+
self.x = np.random.random([10, 10]).astype(np.float32)
90+
self.y = np.random.random([10]).astype(np.float32)
91+
self.offset = 0
92+
self.axis1 = 0
93+
self.axis2 = 1
94+
95+
def set_api(self):
96+
self.ref_api = fill_gt
97+
self.paddle_api = paddle.diagonal_scatter
98+
99+
def get_output(self):
100+
self.output = self.ref_api(
101+
self.x, self.y, self.offset, self.axis1, self.axis2
102+
)
103+
104+
def setUp(self):
105+
self.set_api()
106+
self.set_args()
107+
self.get_output()
108+
109+
def test_dygraph(self):
110+
paddle.disable_static()
111+
x = paddle.to_tensor(self.x, self.dtype)
112+
y = paddle.to_tensor(self.y, self.dtype)
113+
result = paddle.diagonal_scatter(
114+
x, y, offset=self.offset, axis1=self.axis1, axis2=self.axis2
115+
)
116+
np.testing.assert_allclose(self.output, result.numpy(), rtol=1e-5)
117+
paddle.enable_static()
118+
119+
def test_static(self):
120+
if self.dtype not in [
121+
"float16",
122+
"float32",
123+
"float64",
124+
"int16",
125+
"int32",
126+
"int64",
127+
"bool",
128+
"uint16",
129+
]:
130+
return
131+
paddle.enable_static()
132+
startup_program = base.Program()
133+
train_program = base.Program()
134+
with base.program_guard(startup_program, train_program):
135+
x = paddle.static.data(
136+
name="X", shape=self.x.shape, dtype=self.dtype
137+
)
138+
y = paddle.static.data(
139+
name="Y", shape=self.y.shape, dtype=self.dtype
140+
)
141+
out = paddle.diagonal_scatter(
142+
x, y, offset=self.offset, axis1=self.axis1, axis2=self.axis2
143+
)
144+
145+
place = (
146+
base.CUDAPlace(0)
147+
if core.is_compiled_with_cuda()
148+
else base.CPUPlace()
149+
)
150+
151+
exe = base.Executor(place)
152+
result = exe.run(
153+
base.default_main_program(),
154+
feed={"X": self.x, "Y": self.y},
155+
fetch_list=[out],
156+
)
157+
np.testing.assert_allclose(self.output, result[0], rtol=1e-5)
158+
paddle.disable_static()
159+
160+
161+
# check the data type of the input
162+
class TestDiagonalScatterFloat16(TestDiagonalScatterAPI):
163+
def set_args(self):
164+
self.dtype = "float16"
165+
self.x = np.random.random([10, 10]).astype(np.float16)
166+
self.y = np.random.random([10]).astype(np.float16)
167+
self.offset = 0
168+
self.axis1 = 0
169+
self.axis2 = 1
170+
171+
172+
class TestDiagonalScatterFloat64(TestDiagonalScatterAPI):
173+
def set_args(self):
174+
self.dtype = "float64"
175+
self.x = np.random.random([10, 10]).astype(np.float64)
176+
self.y = np.random.random([10]).astype(np.float64)
177+
self.offset = 0
178+
self.axis1 = 0
179+
self.axis2 = 1
180+
181+
182+
@unittest.skipIf(
183+
not core.is_compiled_with_cuda()
184+
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
185+
"core is not compiled with CUDA or not support bfloat16",
186+
)
187+
class TestDiagonalScatterBFloat16(TestDiagonalScatterAPI):
188+
def set_args(self):
189+
self.dtype = "bfloat16"
190+
self.x = convert_float_to_uint16(
191+
np.random.random([10, 10]).astype(np.float32)
192+
)
193+
self.y = convert_float_to_uint16(
194+
np.random.random([10]).astype(np.float32)
195+
)
196+
self.offset = 0
197+
self.axis1 = 0
198+
self.axis2 = 1
199+
200+
201+
class TestDiagoalScatterUInt8(TestDiagonalScatterAPI):
202+
def set_args(self):
203+
self.dtype = "uint8"
204+
self.x = np.random.randint(0, 255, [10, 10]).astype(np.uint8)
205+
self.y = np.random.randint(0, 255, [10]).astype(np.uint8)
206+
self.offset = 0
207+
self.axis1 = 0
208+
self.axis2 = 1
209+
210+
211+
class TestDiagoalScatterInt8(TestDiagonalScatterAPI):
212+
def set_args(self):
213+
self.dtype = "int8"
214+
self.x = np.random.randint(-128, 127, [10, 10]).astype(np.int8)
215+
self.y = np.random.randint(-128, 127, [10]).astype(np.int8)
216+
self.offset = 0
217+
self.axis1 = 0
218+
self.axis2 = 1
219+
220+
221+
class TestDiagoalScatterInt32(TestDiagonalScatterAPI):
222+
def set_args(self):
223+
self.dtype = "int32"
224+
self.x = np.random.randint(-2147483648, 2147483647, [10, 10]).astype(
225+
np.int32
226+
)
227+
self.y = np.random.randint(-2147483648, 2147483647, [10]).astype(
228+
np.int32
229+
)
230+
self.offset = 0
231+
self.axis1 = 0
232+
self.axis2 = 1
233+
234+
235+
class TestDiagoalScatterInt64(TestDiagonalScatterAPI):
236+
def set_args(self):
237+
self.dtype = "int64"
238+
self.x = np.random.randint(
239+
-9223372036854775808, 9223372036854775807, [10, 10]
240+
).astype(np.int64)
241+
self.y = np.random.randint(
242+
-9223372036854775808, 9223372036854775807, [10]
243+
).astype(np.int64)
244+
self.offset = 0
245+
self.axis1 = 0
246+
self.axis2 = 1
247+
248+
249+
class TestDiagoalScatterBool(TestDiagonalScatterAPI):
250+
def set_args(self):
251+
self.dtype = "bool"
252+
self.x = np.random.randint(0, 1, [10, 10]).astype(np.bool_)
253+
self.y = np.random.randint(0, 1, [10]).astype(np.bool_)
254+
self.offset = 0
255+
self.axis1 = 0
256+
self.axis2 = 1
257+
258+
259+
class TestDiagoalScatterComplex64(TestDiagonalScatterAPI):
260+
def set_args(self):
261+
self.dtype = "complex64"
262+
self.x = np.random.random([10, 10]).astype(np.float32)
263+
self.x = self.x + 1j * self.x
264+
self.y = np.random.random([10]).astype(np.float32)
265+
self.y = self.y + 1j * self.y
266+
self.offset = 0
267+
self.axis1 = 0
268+
self.axis2 = 1
269+
270+
271+
class TestDiagoalScatterComplex128(TestDiagonalScatterAPI):
272+
def set_args(self):
273+
self.dtype = "complex128"
274+
self.x = np.random.random([10, 10]).astype(np.float64)
275+
self.x = self.x + 1j * self.x
276+
self.y = np.random.random([10]).astype(np.float64)
277+
self.y = self.y + 1j * self.y
278+
self.offset = 0
279+
self.axis1 = 0
280+
self.axis2 = 1
281+
282+
283+
# check offset, axis
284+
class TestDiagoalScatterOffset(TestDiagonalScatterAPI):
285+
def set_args(self):
286+
self.dtype = "float32"
287+
self.x = np.random.random([10, 10]).astype(np.float32)
288+
self.y = np.random.random([9]).astype(np.float32)
289+
self.offset = 1
290+
self.axis1 = 0
291+
self.axis2 = 1
292+
293+
294+
class TestDiagoalScatterOffset2(TestDiagonalScatterAPI):
295+
def set_args(self):
296+
self.dtype = "float32"
297+
self.x = np.random.random([10, 10]).astype(np.float32)
298+
self.y = np.random.random([8]).astype(np.float32)
299+
self.offset = -2
300+
self.axis1 = 0
301+
self.axis2 = 1
302+
303+
304+
class TestDiagoalScatterAxis1(TestDiagonalScatterAPI):
305+
def set_args(self):
306+
self.dtype = "float32"
307+
self.x = np.random.random([10, 10]).astype(np.float32)
308+
self.y = np.random.random([10]).astype(np.float32)
309+
self.offset = 0
310+
self.axis1 = 1
311+
self.axis2 = 0
312+
313+
314+
# check error
315+
class TestDiagonalScatterError(TestDiagonalScatterAPI):
316+
def test_error_1(self):
317+
paddle.disable_static()
318+
x = paddle.to_tensor([1.0], "float32")
319+
y = paddle.to_tensor([], "float32")
320+
with self.assertRaisesRegex(
321+
AssertionError,
322+
"Tensor x must be at least 2-dimensional in diagonal_scatter",
323+
):
324+
paddle.diagonal_scatter(x, y)
325+
paddle.enable_static()
326+
327+
def test_error_2(self):
328+
# axis1 is out of range in diagonal_scatter (expected to be in range of [-2, 2), but got 1000)
329+
paddle.disable_static()
330+
x = paddle.to_tensor(self.x, self.dtype)
331+
y = paddle.to_tensor(self.y, self.dtype)
332+
axis1 = 1000
333+
with self.assertRaises(AssertionError):
334+
paddle.diagonal_scatter(x, y, self.offset, axis1, self.axis2)
335+
paddle.enable_static()
336+
337+
def test_error_3(self):
338+
# axis2 is out of range in diagonal_scatter (expected to be in range of [-2, 2), but got -1000)
339+
paddle.disable_static()
340+
x = paddle.to_tensor(self.x, self.dtype)
341+
y = paddle.to_tensor(self.y, self.dtype)
342+
axis2 = -1000
343+
with self.assertRaises(AssertionError):
344+
paddle.diagonal_scatter(x, y, self.offset, self.axis1, axis2)
345+
paddle.enable_static()
346+
347+
def test_error_4(self):
348+
# axis1 and axis2 should not be identical in diagonal_scatter, but received axis1 = 0, axis2 = 0
349+
paddle.disable_static()
350+
x = paddle.to_tensor(self.x, self.dtype)
351+
y = paddle.to_tensor(self.y, self.dtype)
352+
axis1 = 0
353+
axis2 = 0
354+
with self.assertRaises(AssertionError):
355+
paddle.diagonal_scatter(x, y, self.offset, axis1, axis2)
356+
paddle.enable_static()
357+
358+
359+
if __name__ == "__main__":
360+
unittest.main()

0 commit comments

Comments
 (0)