|
24 | 24 | import tvm.topi.testing
|
25 | 25 | from tvm import relay
|
26 | 26 | from tvm.contrib import graph_executor
|
| 27 | +import pytest |
27 | 28 |
|
28 | 29 | import paddle
|
29 | 30 | import paddle.nn as nn
|
@@ -127,8 +128,6 @@ def add_subtract3(inputs1, inputs2):
|
127 | 128 |
|
128 | 129 | @tvm.testing.uses_gpu
|
129 | 130 | def test_forward_arg_max_min():
|
130 |
| - input_shape = [1, 3, 10, 10] |
131 |
| - |
132 | 131 | class ArgMax(nn.Layer):
|
133 | 132 | @paddle.jit.to_static
|
134 | 133 | def forward(self, inputs):
|
@@ -169,32 +168,50 @@ class ArgMin3(nn.Layer):
|
169 | 168 | def forward(self, inputs):
|
170 | 169 | return inputs.argmin(axis=2, keepdim=True)
|
171 | 170 |
|
172 |
| - input_data = paddle.rand(input_shape, dtype="float32") |
173 |
| - verify_model(ArgMax(), input_data=input_data) |
174 |
| - verify_model(ArgMax1(), input_data=input_data) |
175 |
| - verify_model(ArgMax2(), input_data=input_data) |
176 |
| - verify_model(ArgMax3(), input_data=input_data) |
177 |
| - verify_model(ArgMin(), input_data=input_data) |
178 |
| - verify_model(ArgMin1(), input_data=input_data) |
179 |
| - verify_model(ArgMin2(), input_data=input_data) |
180 |
| - verify_model(ArgMin3(), input_data=input_data) |
| 171 | + input_shapes = [[256], [10, 128], [100, 500, 200], [1, 3, 224, 224]] |
| 172 | + for input_shape in input_shapes: |
| 173 | + input_data = paddle.rand(input_shape, dtype="float32") |
| 174 | + verify_model(ArgMax(), input_data=input_data) |
| 175 | + verify_model(ArgMin(), input_data=input_data) |
| 176 | + for input_shape in input_shapes[1:]: |
| 177 | + input_data = paddle.rand(input_shape, dtype="float32") |
| 178 | + verify_model(ArgMax1(), input_data=input_data) |
| 179 | + verify_model(ArgMax2(), input_data=input_data) |
| 180 | + verify_model(ArgMin1(), input_data=input_data) |
| 181 | + verify_model(ArgMin2(), input_data=input_data) |
| 182 | + for input_shape in input_shapes[2:]: |
| 183 | + input_data = paddle.rand(input_shape, dtype="float32") |
| 184 | + verify_model(ArgMax3(), input_data=input_data) |
| 185 | + verify_model(ArgMin3(), input_data=input_data) |
181 | 186 |
|
182 | 187 |
|
@tvm.testing.uses_gpu
def test_forward_argsort():
    """Check conversion of paddle.argsort for 1D-4D inputs and several axis/order combos."""

    class ArgSort1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # Defaults: ascending order along the last axis.
            return paddle.argsort(inputs)

    class ArgSort2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.argsort(inputs, axis=0, descending=True)

    class ArgSort3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.argsort(inputs, axis=-1, descending=True)

    input_shapes = [[256], [10, 20], [10, 10, 3], [1, 3, 5, 5]]
    for input_shape in input_shapes:
        # Avoid duplicate elements in the array which will bring
        # different results with different sort algorithms
        # (sampling without replacement guarantees all values are distinct,
        # so the argsort result is unique regardless of sort stability).
        np.random.seed(13)
        np_data = np.random.choice(range(-5000, 5000), np.prod(input_shape), replace=False)
        input_data = paddle.to_tensor(np_data.reshape(input_shape).astype("int64"))
        verify_model(ArgSort1(), [input_data])
        verify_model(ArgSort2(), [input_data])
        verify_model(ArgSort3(), [input_data])
198 | 215 |
|
199 | 216 |
|
200 | 217 | @tvm.testing.uses_gpu
|
@@ -291,23 +308,27 @@ def cast2(inputs, dtype="int64"):
|
291 | 308 |
|
@tvm.testing.uses_gpu
def test_forward_check_tensor():
    """Check conversion of the tensor-inspection ops isfinite/isnan/isinf.

    Results are cast to int32 because verify_model compares numeric outputs.
    """

    class IsFinite(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.cast(paddle.isfinite(inputs), "int32")

    class IsNan(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.cast(paddle.isnan(inputs), "int32")

    class IsInf(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.cast(paddle.isinf(inputs), "int32")

    # Exercise ranks 1 through 5.
    shapes = [[32], [8, 128], [2, 128, 256], [2, 3, 224, 224], [2, 2, 3, 229, 229]]
    for shape in shapes:
        sample = paddle.rand(shape, dtype="float32")
        for model in (IsFinite(), IsNan(), IsInf()):
            verify_model(model, input_data=sample)
311 | 332 |
|
312 | 333 |
|
313 | 334 | @tvm.testing.uses_gpu
|
@@ -391,15 +412,16 @@ def forward(self, inputs):
|
391 | 412 |
|
@tvm.testing.uses_gpu
def test_forward_dot():
    """Check conversion of paddle.dot for 1D and batched (2D) operands."""

    class Dot(nn.Layer):
        @paddle.jit.to_static
        def forward(self, x, y):
            return paddle.dot(x, y)

    for shape in ([128], [8, 128]):
        lhs = paddle.rand(shape, dtype="float32")
        rhs = paddle.rand(shape, dtype="float32")
        verify_model(Dot(), input_data=[lhs, rhs])
403 | 425 |
|
404 | 426 |
|
405 | 427 | @tvm.testing.uses_gpu
|
@@ -435,44 +457,70 @@ def forward(self, input1, input2):
|
435 | 457 | api_list = [
|
436 | 458 | "equal",
|
437 | 459 | ]
|
438 |
| - input_shape = [10, 10] |
439 |
| - input_shape_2 = [ |
440 |
| - 10, |
441 |
| - ] |
442 |
| - x_data = paddle.randint(1, 10, input_shape, dtype="int32") |
443 |
| - y_data = paddle.randint(1, 10, input_shape_2, dtype="int32") |
444 |
| - for api_name in api_list: |
445 |
| - verify_model(ElemwiseAPI(api_name), [x_data, y_data]) |
| 460 | + x_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]] |
| 461 | + y_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]] |
| 462 | + for x_shape, y_shape in zip(x_shapes, y_shapes): |
| 463 | + x_data = paddle.randint(1, 1000, x_shape, dtype="int32") |
| 464 | + y_data = paddle.randint(1, 1000, y_shape, dtype="int32") |
| 465 | + for api_name in api_list: |
| 466 | + verify_model(ElemwiseAPI(api_name), [x_data, y_data]) |
446 | 467 |
|
447 | 468 |
|
@tvm.testing.uses_gpu
def test_forward_expand():
    """Check conversion of paddle.expand with static (list) and dynamic (tensor) shapes."""

    # expand1-3: target shape passed as a Python list (known at trace time).
    @paddle.jit.to_static
    def expand1(inputs):
        return paddle.expand(inputs, shape=[2, 128])

    @paddle.jit.to_static
    def expand2(inputs):
        return paddle.expand(inputs, shape=[3, 1, 8, 256])

    @paddle.jit.to_static
    def expand3(inputs):
        return paddle.expand(inputs, shape=[5, 1, 3, 224, 224])

    # expand4-6: same target shapes, but passed as a tensor so the frontend
    # must handle a shape that is only known as graph input.
    @paddle.jit.to_static
    def expand4(inputs):
        shape = paddle.to_tensor(np.array([2, 128]).astype("int32"))
        return paddle.expand(inputs, shape=shape)

    @paddle.jit.to_static
    def expand5(inputs):
        shape = paddle.to_tensor(np.array([3, 1, 8, 256]).astype("int32"))
        return paddle.expand(inputs, shape=shape)

    @paddle.jit.to_static
    def expand6(inputs):
        shape = paddle.to_tensor(np.array([5, 1, 3, 224, 224]).astype("int32"))
        return paddle.expand(inputs, shape=shape)

    # Each input shape is broadcast-compatible with its paired target shape.
    data = paddle.rand([128], dtype="float32")
    verify_model(expand1, input_data=[data])
    verify_model(expand4, input_data=[data])
    data = paddle.rand([8, 256], dtype="float32")
    verify_model(expand2, input_data=[data])
    verify_model(expand5, input_data=[data])
    data = paddle.rand([1, 3, 224, 224], dtype="float32")
    verify_model(expand3, input_data=[data])
    verify_model(expand6, input_data=[data])
463 | 507 |
|
464 | 508 |
|
@tvm.testing.uses_gpu
def test_forward_expand_as():
    """Check conversion of paddle.expand_as against targets of rank 1-5."""

    class ExpandAs(nn.Layer):
        @paddle.jit.to_static
        def forward(self, x, y):
            z = paddle.expand_as(x, y)
            # Add the target tensor so the expanded result feeds a real op.
            z += y
            return z

    x_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]]
    y_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]]
    for src_shape, dst_shape in zip(x_shapes, y_shapes):
        src = paddle.rand(src_shape, dtype="float32")
        dst = paddle.rand(dst_shape, dtype="float32")
        verify_model(ExpandAs(), [src, dst])
476 | 524 |
|
477 | 525 |
|
478 | 526 | @tvm.testing.uses_gpu
|
@@ -591,11 +639,14 @@ def forward(self, x, y):
|
591 | 639 | z = self.func(x, y, out=out)
|
592 | 640 | return paddle.cast(z, "int32")
|
593 | 641 |
|
594 |
| - x = paddle.to_tensor([True]) |
595 |
| - y = paddle.to_tensor([True, False, True, False]) |
596 |
| - verify_model(LogicalAPI("logical_and"), [x, y]) |
597 |
| - verify_model(LogicalAPI("logical_or"), [x, y]) |
598 |
| - verify_model(LogicalAPI("logical_xor"), [x, y]) |
| 642 | + x_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]] |
| 643 | + y_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]] |
| 644 | + for x_shape, y_shape in zip(x_shapes, y_shapes): |
| 645 | + x_data = paddle.randint(0, 2, x_shape).astype("bool") |
| 646 | + y_data = paddle.randint(0, 2, y_shape).astype("bool") |
| 647 | + verify_model(LogicalAPI("logical_and"), [x_data, y_data]) |
| 648 | + verify_model(LogicalAPI("logical_or"), [x_data, y_data]) |
| 649 | + verify_model(LogicalAPI("logical_xor"), [x_data, y_data]) |
599 | 650 |
|
600 | 651 |
|
601 | 652 | @tvm.testing.uses_gpu
|
@@ -796,11 +847,13 @@ def forward(self, inputs):
|
796 | 847 | "relu",
|
797 | 848 | "tanh",
|
798 | 849 | ]
|
799 |
| - input_shape = [1, 3, 10, 10] |
800 |
| - input_data = paddle.rand(input_shape, dtype="float32") |
801 |
| - for api_name in api_list: |
802 |
| - verify_model(MathAPI(api_name), input_data=input_data) |
| 850 | + input_shapes = [[128], [2, 256], [1000, 128, 32], [7, 3, 256, 256]] |
| 851 | + for input_shape in input_shapes: |
| 852 | + input_data = paddle.rand(input_shape, dtype="float32") |
| 853 | + for api_name in api_list: |
| 854 | + verify_model(MathAPI(api_name), input_data=input_data) |
803 | 855 |
|
804 | 856 |
|
if __name__ == "__main__":
    # Run the whole module under pytest. Calling a single test function
    # directly (test_forward_math_api()) was leftover debugging and silently
    # skipped every other test in this file when run as a script.
    pytest.main([__file__])
0 commit comments