Skip to content

Commit

Permalink
Merge pull request apache#25 from wjj19950828/paddle_frontend
Browse files Browse the repository at this point in the history
Paddle frontend
  • Loading branch information
jiangjiajun authored Sep 9, 2021
2 parents 7ef3b0b + 33b050b commit 986914f
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 26 deletions.
86 changes: 85 additions & 1 deletion python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,9 @@ def convert_elementwise_op(g, op, block):
"elementwise_add": lambda x, y: x + y,
"elementwise_mul": lambda x, y: x * y,
"elementwise_sub": lambda x, y: x - y,
"elementwise_mod": lambda x, y: x % y,
"elementwise_mod": _op.mod,
"elementwise_pow": _op.power,
"elementwise_floordiv": _op.floor_divide
}
op_func = op_map[op.type]
ipt0 = g.get_node(op.input("X")[0])
Expand Down Expand Up @@ -744,6 +746,15 @@ def convert_leaky_relu(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_less_than(g, op, block):
"""Operator converter for less_than."""

x = g.get_node(op.input("X")[0])
y = g.get_node(op.input("Y")[0])
out = _op.less(x, y)
g.add_node(op.output("Out")[0], out)


def convert_lookup_table(g, op, block):
"""Operator converter for lookup_table_v2."""

Expand All @@ -767,6 +778,20 @@ def convert_log1p(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_logsumexp(g, op, block):
"""Operator converter for logsumexp."""

input_x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
if op.attr("reduce_all"):
axis = None
keepdims = op.attr("keepdim")
out = get_relay_op("logsumexp")(input_x, axis=axis, keepdims=keepdims)
if not axis and not keepdims:
out = _op.expand_dims(out, axis=0)
g.add_node(op.output("Out")[0], out)


def convert_matmul(g, op, block):
"""Operator converter for matmul."""

Expand Down Expand Up @@ -921,6 +946,15 @@ def convert_mul(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_not_equal(g, op, block):
"""Operator converter for not_equal."""

x = g.get_node(op.input("X")[0])
y = g.get_node(op.input("Y")[0])
out = _op.not_equal(x, y)
g.add_node(op.output("Out")[0], out)


def convert_pool2d(g, op, block):
"""Operator converter for pool2d."""

Expand Down Expand Up @@ -1007,6 +1041,17 @@ def convert_padding(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_pow(g, op, block):
"""Operator converter for pow."""

x = g.get_node(op.input("X")[0])
factor = op.attr("factor")
factor = _expr.const(factor, dtype="float32").astype("float32")

out = _op.power(x, factor)
g.add_node(op.output("Out")[0], out)


def convert_range(g, op, block):
"""Operator converter for range."""

Expand All @@ -1033,6 +1078,12 @@ def convert_reduce(g, op, block):

op_map = {
"reduce_all": "all",
"reduce_any": "any",
"reduce_max": "max",
"reduce_min": "min",
"reduce_prod": "prod",
"reduce_sum": "sum",
"reduce_mean": "mean",
}
op_name = op_map[op.type]
input_x = g.get_node(op.input("X")[0])
Expand Down Expand Up @@ -1307,6 +1358,25 @@ def convert_squeeze(g, op, block):
g.add_node(op.output("Out")[0], x)


def convert_topk(g, op, block):
"""Operator converter for topk."""

x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
largest = op.attr("largest")
is_ascend = not bool(largest)
k_node = op.input("K")
if k_node:
k_node = g.get_node(k_node[0])
k = _infer_value(k_node, g.get_params())
else:
k = op.attr("k")
outs = _op.topk(x, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int32")

g.add_node(op.output("Out")[0], outs[0])
g.add_node(op.output("Indices")[0], outs[1])


def convert_stack(g, op, block):
"""Operator converter for stack."""

Expand Down Expand Up @@ -1392,6 +1462,9 @@ def convert_unsqueeze(g, op, block):
"elementwise_div": convert_elementwise_op,
"elementwise_mul": convert_elementwise_op,
"elementwise_sub": convert_elementwise_op,
"elementwise_mod": convert_elementwise_op,
"elementwise_pow": convert_elementwise_op,
"elementwise_floordiv": convert_elementwise_op,
"equal": convert_equal,
"exp": convert_unary_op,
"expand_v2": convert_expand,
Expand All @@ -1410,21 +1483,31 @@ def convert_unsqueeze(g, op, block):
"isinf_v2": convert_unary_op,
"layer_norm": convert_layer_norm,
"leaky_relu": convert_leaky_relu,
"less_than": convert_less_than,
"lookup_table": convert_lookup_table,
"lookup_table_v2": convert_lookup_table,
"log": convert_unary_op,
"log10": convert_unary_op,
"log1p": convert_log1p,
"logsumexp": convert_logsumexp,
"matmul": convert_matmul,
"matmul_v2": convert_matmul,
"mul": convert_mul,
"nearest_interp_v2": convert_interpolate,
"not_equal": convert_not_equal,
"pool2d": convert_pool2d,
"pad1d": convert_padding,
"pad2d": convert_padding,
"pad3d": convert_padding,
"pow": convert_pow,
"range": convert_range,
"reduce_all": convert_reduce,
"reduce_any": convert_reduce,
"reduce_max": convert_reduce,
"reduce_min": convert_reduce,
"reduce_prod": convert_reduce,
"reduce_sum": convert_reduce,
"reduce_mean": convert_reduce,
"relu": convert_unary_op,
"reshape2": convert_reshape,
"rnn": convert_rnn,
Expand All @@ -1438,6 +1521,7 @@ def convert_unsqueeze(g, op, block):
"stack": convert_stack,
"tan": convert_unary_op,
"tanh": convert_unary_op,
"top_k_v2": convert_topk,
"tile": convert_tile,
"transpose2": convert_transpose,
"unsqueeze2": convert_unsqueeze,
Expand Down
Loading

0 comments on commit 986914f

Please sign in to comment.