Skip to content

Commit

Permalink
[FastMath] Add fast_softmax support in fast_math pass (apache#8138)
Browse files Browse the repository at this point in the history
* Add fast_softmax support in fast_math pass

* Lintfix

* Update
  • Loading branch information
jcf94 authored and Trevor Morris committed Jun 17, 2021
1 parent 10a3f25 commit 371c2f8
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.fast_softmax),
naive_schedule,
wrap_topi_schedule(topi.generic.schedule_fast_softmax),
name="fast_softmax.generic",
)
return strategy
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,23 @@ def schedule_softmax(outs):
return _default_schedule(outs, False)


def schedule_fast_softmax(outs):
"""Schedule for fast_softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of fast_softmax
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_dense(outs):
"""Schedule for dense
Expand Down
9 changes: 8 additions & 1 deletion src/relay/transforms/fast_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ namespace relay {

class FastMathMutator : public ExprRewriter {
public:
FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {}
FastMathMutator()
: exp_op_(Op::Get("exp")),
erf_op_(Op::Get("erf")),
tanh_op_(Op::Get("tanh")),
softmax_op_(Op::Get("nn.softmax")) {}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (pre->op == exp_op_) {
Expand All @@ -43,6 +47,8 @@ class FastMathMutator : public ExprRewriter {
return FastErf(post.as<CallNode>()->args[0]);
} else if (pre->op == tanh_op_) {
return FastTanh(post.as<CallNode>()->args[0]);
} else if (pre->op == softmax_op_) {
return FastSoftmax(post.as<CallNode>()->args[0], post.as<CallNode>()->attrs);
}
return post;
}
Expand All @@ -54,6 +60,7 @@ class FastMathMutator : public ExprRewriter {
const Op& exp_op_;
const Op& erf_op_;
const Op& tanh_op_;
const Op& softmax_op_;
};

Expr FastMath(const Expr& e) {
Expand Down
5 changes: 5 additions & 0 deletions src/relay/transforms/pattern_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,11 @@ inline Expr FastTanh(Expr e) {
return Call(op, {e});
}

inline Expr FastSoftmax(Expr e, tvm::Attrs attr) {
static const Op& op = Op::Get("nn.fast_softmax");
return Call(op, {e}, attr);
}

inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return Call(op, {e});
Expand Down
10 changes: 9 additions & 1 deletion tests/python/relay/test_op_fast_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def test_fastmath():
def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
a_np = np.arange(low, high, step).astype(dtype)
a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
b_np = f_numpy(a_np)

x = relay.var("x", shape=a_np.shape, dtype="float32")
Expand Down Expand Up @@ -56,6 +56,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
test_apply(
relay.nn.fast_softmax,
"nn_fast_softmax",
tvm.topi.testing.softmax_python,
low=-10,
high=10,
step=0.01,
)


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions tests/python/relay/test_pass_fast_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,19 @@ def test_erf():
assert "fast_erf" in fast_mod[0].astext()


def test_softmax():
x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.nn.softmax(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
fast_mod = relay.optimize(mod, target="llvm")
assert "nn.fast_softmax" in fast_mod[0].astext()


if __name__ == "__main__":
test_exp()
test_tanh()
test_erf()
test_softmax()

0 comments on commit 371c2f8

Please sign in to comment.