
Commit a0849a9

Merge pull request apache#27 from wjj19950828/paddle_frontend
add 5 paddle ops
2 parents 2d2217b + 0ce963a commit a0849a9

File tree

2 files changed: +196 −0


python/tvm/relay/frontend/paddlepaddle.py

Lines changed: 41 additions & 0 deletions
@@ -706,6 +706,17 @@ def convert_hard_swish(g, op, block):
     g.add_node(op.output("Out")[0], out)


+def convert_index_select(g, op, block):
+    """Operator converter for index_select."""
+
+    dim = op.attr("dim")
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    out = _op.take(x, indices=index, axis=dim, mode="clip")
+
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_layer_norm(g, op, block):
     """Operator converter for layer_norm."""

@@ -1034,6 +1045,33 @@ def convert_pow(g, op, block):
     g.add_node(op.output("Out")[0], out)


+def convert_norm(g, op, block):
+    """Operator converter for norm."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    axis = op.attr("axis")
+    keepdim = op.attr("keepdim")
+    if op.attr("asvector"):
+        axis = None
+    order = op.attr("porder")
+    if order == np.inf:
+        out = _op.reduce.max(_op.abs(x), axis=axis, keepdims=keepdim)
+    elif order == np.NINF:
+        out = _op.reduce.min(_op.abs(x), axis=axis, keepdims=keepdim)
+    else:
+        reci_order = _expr.const(1.0 / order, dtype=dtype)
+        order = _expr.const(order)
+        out = _op.power(
+            _op.reduce.sum(_op.power(_op.abs(x), order), axis=axis, keepdims=keepdim),
+            reci_order,
+        )
+    if op.attr("asvector") and not keepdim:
+        out = _op.expand_dims(out, axis=0)
+
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_range(g, op, block):
     """Operator converter for range."""
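The converter covers three regimes of paddle's p_norm attribute porder: +inf reduces to the max of |x|, -inf to the min of |x|, and any finite order to (sum(|x|**p))**(1/p); asvector=True reduces over all axes and, without keepdim, restores a rank-1 output via expand_dims. A hedged NumPy reference of the same arithmetic, useful for checking the lowering by hand (the function name is illustrative):

import numpy as np

def p_norm_reference(x, porder, axis=None, keepdim=False):
    # Mirrors convert_norm above: max/min of |x| for +/-inf,
    # otherwise the generalized power-sum formula.
    ax = np.abs(x)
    if porder == np.inf:
        return np.max(ax, axis=axis, keepdims=keepdim)
    if porder == -np.inf:
        return np.min(ax, axis=axis, keepdims=keepdim)
    return np.power(
        np.sum(np.power(ax, porder), axis=axis, keepdims=keepdim),
        1.0 / porder,
    )

x = np.random.rand(1, 3, 10, 10).astype("float32")
print(p_norm_reference(x, 2.0, axis=1).shape)  # (1, 10, 10)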

@@ -1534,6 +1572,7 @@ def convert_unsqueeze(g, op, block):
     "greater_than": convert_elementwise_op,
     "hard_sigmoid": convert_hard_sigmoid,
     "hard_swish": convert_hard_swish,
+    "index_select": convert_index_select,
     "isinf": convert_unary_op,
     "isinf_v2": convert_unary_op,
     "layer_norm": convert_layer_norm,
@@ -1543,6 +1582,7 @@ def convert_unsqueeze(g, op, block):
     "lookup_table": convert_lookup_table,
     "lookup_table_v2": convert_lookup_table,
     "log": convert_unary_op,
+    "log2": convert_unary_op,
     "log10": convert_unary_op,
     "log1p": convert_log1p,
     "logsumexp": convert_logsumexp,
@@ -1556,6 +1596,7 @@ def convert_unsqueeze(g, op, block):
     "pad2d": convert_padding,
     "pad3d": convert_padding,
     "pow": convert_pow,
+    "p_norm": convert_norm,
     "range": convert_range,
     "reduce_all": convert_reduce,
     "reduce_any": convert_reduce,

tests/python/frontend/paddlepaddle/test_forward.py

Lines changed: 155 additions & 0 deletions
@@ -156,6 +156,7 @@ def forward(self, inputs):
     "exp",
     "floor",
     "log",
+    "log2",
     "log10",
     "log1p",
     "relu",
@@ -627,6 +628,26 @@ def ones_like2(inputs):


 @tvm.testing.uses_gpu
+def test_forward_ones():
+    @paddle.jit.to_static
+    def ones1(inputs):
+        ones = paddle.ones([1, 3, 10, 10])
+        out = inputs + ones
+        return out
+
+    @paddle.jit.to_static
+    def ones2(inputs):
+        shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32")
+        ones = paddle.ones(shape)
+        out = inputs + ones
+        return out
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(ones1, input_data=input_data)
+    verify_model(ones2, input_data=input_data)
+
+
 def test_forward_elemwise():
     class ElemwiseOp(nn.Layer):
         def __init__(self, op_name):
@@ -733,6 +754,23 @@ def hard_swish(inputs):
     verify_model(hard_swish, input_data=input_data)


+@tvm.testing.uses_gpu
+def test_forward_index_select():
+    @paddle.jit.to_static
+    def index_select1(x, index):
+        return paddle.index_select(x, index)
+
+    @paddle.jit.to_static
+    def index_select2(x, index):
+        return paddle.index_select(x, index, axis=1)
+
+    input_shape = [3, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    index = paddle.to_tensor(np.array([0, 1, 1]).astype("int32"))
+    verify_model(index_select1, input_data=[input_data, index])
+    verify_model(index_select2, input_data=[input_data, index])
+
+
 @tvm.testing.uses_gpu
 def test_forward_isinf():
     @paddle.jit.to_static
@@ -911,6 +949,97 @@ def forward(self, input1, input2):
     verify_model(MatMul1(), input_data=[input_data1, input_data2])


+@tvm.testing.uses_gpu
+def test_forward_norm():
+    class Norm1(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=None, keepdim=False)
+
+    class Norm2(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=False)
+
+    class Norm3(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=True)
+
+    class Norm4(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=[1, 2], keepdim=False)
+
+    class Norm5(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=-1, keepdim=True)
+
+    class Norm6(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float(0.5), axis=1, keepdim=True)
+
+    class Norm7(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float(1), axis=None, keepdim=False)
+
+    class Norm8(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float(2.0), axis=1, keepdim=False)
+
+    class Norm9(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float(-0.5), axis=[1, 2], keepdim=False)
+
+    class Norm10(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float(-2), axis=(1), keepdim=False)
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(Norm1(), input_data=input_data)
+    verify_model(Norm2(), input_data=input_data)
+    verify_model(Norm3(), input_data=input_data)
+    verify_model(Norm4(), input_data=input_data)
+    verify_model(Norm5(), input_data=input_data)
+    verify_model(Norm6(), input_data=input_data)
+    verify_model(Norm7(), input_data=input_data)
+    verify_model(Norm8(), input_data=input_data)
+    verify_model(Norm9(), input_data=input_data)
+    verify_model(Norm10(), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_not_equal():
+    class Not_equal(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, x, y):
+            output = paddle.not_equal(x, y)
+            output = paddle.cast(output, "int32")
+            return output
+
+    x_shape = [10]
+    y_shape = [10]
+    x_data = paddle.randint(1, 10, x_shape, dtype="int32")
+    y_data = paddle.randint(1, 10, y_shape, dtype="int32")
+    x_data_1 = paddle.randint(1, 10, x_shape, dtype="int64")
+    y_data_1 = paddle.randint(1, 10, y_shape, dtype="int64")
+    verify_model(Not_equal(), input_data=[x_data, y_data])
+    verify_model(Not_equal(), input_data=[x_data_1, y_data_1])
+    # For broadcast
+    x_shape_1 = [10]
+    y_shape_1 = [10, 1]
+    x_data_2 = paddle.rand(x_shape_1, dtype="float32")
+    y_data_2 = paddle.rand(y_shape_1, dtype="float32")
+    verify_model(Not_equal(), input_data=[x_data_2, y_data_2])
+
+
 @tvm.testing.uses_gpu
 def test_forward_pool2d():
     @paddle.jit.to_static
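Regarding the broadcast case at the end of test_forward_not_equal above: it pairs a [10] tensor with a [10, 1] tensor, so under standard broadcasting the comparison yields a [10, 10] result. A quick NumPy check of that shape rule (illustrative only):

import numpy as np

x = np.random.rand(10).astype("float32")
y = np.random.rand(10, 1).astype("float32")
# (10,) broadcast against (10, 1) gives (10, 10).
print(np.not_equal(x, y).shape)  # (10, 10)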
@@ -1287,6 +1416,27 @@ def tile3(inputs, inputs2):
     verify_model(tile3, input_data=[input_data, input_data2])


+@tvm.testing.uses_gpu
+def test_forward_zeros():
+    @paddle.jit.to_static
+    def zeros1(inputs):
+        zeros = paddle.zeros([1, 3, 10, 10])
+        out = inputs + zeros
+        return out
+
+    @paddle.jit.to_static
+    def zeros2(inputs):
+        shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32")
+        zeros = paddle.zeros(shape)
+        out = inputs + zeros
+        return out
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(zeros1, input_data=input_data)
+    verify_model(zeros2, input_data=input_data)
+
+
 if __name__ == "__main__":
     test_forward_add_subtract()
     test_forward_addmm()
@@ -1306,12 +1456,14 @@ def tile3(inputs, inputs2):
     test_forward_expand()
     test_forward_flatten()
     test_forward_shape_full()
+    test_forward_ones()
     test_forward_ones_like()
     test_forward_gather_assign_value()
     test_forward_gather_nd()
     test_forward_gelu()
     test_forward_hard_sigmoid()
     test_forward_hard_swish()
+    test_forward_index_select()
     test_forward_interpolate()
     test_forward_isinf()
     test_forward_layer_norm()
@@ -1320,6 +1472,8 @@ def tile3(inputs, inputs2):
     test_forward_lstm()
     test_forward_matmul()
     test_forward_multiply()
+    test_forward_not_equal()
+    test_forward_norm()
     test_forward_pool2d()
     test_forward_pad()
     test_forward_pow()
@@ -1333,3 +1487,4 @@ def tile3(inputs, inputs2):
     test_forward_tile()
     test_forward_conv_transpose()
     test_forward_unary_op()
+    test_forward_zeros()
