[Operator] Add gelu backward (#159)
Co-authored-by: liwei.dai <liwei.dai@iluvatar.com>
BruceDai003 and liwei.dai authored Aug 23, 2024
1 parent adb2094 commit 2c4625e
Showing 4 changed files with 126 additions and 15 deletions.
benchmark/test_pointwise_perf.py (55 additions, 1 deletion)
@@ -182,14 +182,68 @@ def test_perf_ge():
    bench.run()


-def test_perf_gelu():
+def test_perf_gelu_tanh():
+    def gelu_kwargs(dtype, batch, size):
+        return {"approximate": "tanh"}

    bench = Benchmark(
        op_name="gelu",
        torch_op=torch.nn.functional.gelu,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
+        kwargs_func=gelu_kwargs,
    )
    bench.run()


+def test_perf_gelu_none():
+    def gelu_kwargs(dtype, batch, size):
+        return {"approximate": "none"}

+    bench = Benchmark(
+        op_name="gelu",
+        torch_op=torch.nn.functional.gelu,
+        arg_func=unary_arg,
+        dtypes=FLOAT_DTYPES,
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+        kwargs_func=gelu_kwargs,
+    )
+    bench.run()


+def test_perf_gelu_backward_tanh():
+    def gelu_kwargs(dtype, batch, size):
+        return {"approximate": "tanh"}

+    bench = Benchmark(
+        op_name="gelu",
+        torch_op=torch.nn.functional.gelu,
+        arg_func=unary_arg,
+        dtypes=FLOAT_DTYPES,
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+        kwargs_func=gelu_kwargs,
+        is_backward=True,
+    )
+    bench.run()


+def test_perf_gelu_backward_none():
+    def gelu_kwargs(dtype, batch, size):
+        return {"approximate": "none"}

+    bench = Benchmark(
+        op_name="gelu",
+        torch_op=torch.nn.functional.gelu,
+        arg_func=unary_arg,
+        dtypes=FLOAT_DTYPES,
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+        kwargs_func=gelu_kwargs,
+        is_backward=True,
+    )
+    bench.run()

src/flag_gems/__init__.py (1 addition, 1 deletion)
@@ -45,7 +45,7 @@ def enable(lib=aten_lib):
    lib.impl("exponential_", exponential_, "CUDA")
    lib.impl("ge.Tensor", ge, "CUDA")
    lib.impl("ge.Scalar", ge_scalar, "CUDA")
-    lib.impl("gelu", gelu, "CUDA")
+    lib.impl("gelu", gelu, "AutogradCUDA")
    lib.impl("native_group_norm", group_norm, "AutogradCUDA")
    lib.impl("gt.Tensor", gt, "CUDA")
    lib.impl("gt.Scalar", gt_scalar, "CUDA")
src/flag_gems/ops/gelu.py (58 additions, 9 deletions)
@@ -1,23 +1,24 @@
import logging

+import torch
import triton
import triton.language as tl

from ..utils import pointwise_dynamic

try:
-    from triton.language.extra.cuda.libdevice import erf, pow, tanh
+    from triton.language.extra.cuda.libdevice import erf, exp, pow, tanh
except ImportError:
    try:
-        from triton.language.math import erf, pow, tanh
+        from triton.language.math import erf, exp, pow, tanh
    except ImportError:
-        from triton.language.libdevice import erf, pow, tanh
+        from triton.language.libdevice import erf, exp, pow, tanh


@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
-    scale: tl.constexpr = 0.7071067811
+    scale: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    output = 0.5 * x * (1 + erf(x * scale))
    return output

@@ -31,9 +31,57 @@ def gelu_tanh(x):
    return output


+@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def gelu_backward_none(x, dy):
+    scale1: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
+    scale2: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
+    x_fp32 = x.to(tl.float32)
+    dydx = (
+        scale2 * x_fp32 * exp(-pow(scale1 * x_fp32, 2))
+        + 0.5 * erf(scale1 * x_fp32)
+        + 0.5
+    )
+    dx = dydx * dy
+    return dx
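For reference, dydx in gelu_backward_none is the closed-form derivative of exact GELU; writing it out (standard calculus, not taken from the patch), with scale1 = 1/\sqrt{2} and scale2 = 1/\sqrt{2\pi}:

\[
\mathrm{GELU}(x) = x\,\Phi(x), \qquad
\Phi(x) = \tfrac{1}{2}\Bigl(1 + \operatorname{erf}\bigl(\tfrac{x}{\sqrt{2}}\bigr)\Bigr)
\]
\[
\frac{d}{dx}\,\mathrm{GELU}(x)
  = \Phi(x) + x\,\Phi'(x)
  = \tfrac{1}{2}\operatorname{erf}\bigl(\tfrac{x}{\sqrt{2}}\bigr) + \tfrac{1}{2}
  + \frac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2}
\]

This matches the kernel term by term, since exp(-pow(scale1 * x, 2)) = e^{-x^2/2}.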


+@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def gelu_backward_tanh(x, dy):
+    x_fp32 = x.to(tl.float32)
+    # 0.79788456 = math.sqrt(2 / math.pi)
+    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * pow(x_fp32, 2)))
+    dydx = 0.5 * x * (
+        (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * pow(x_fp32, 2))
+    ) + 0.5 * (1 + tanh_out)
+    dx = dydx * dy
+    return dx
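Likewise, gelu_backward_tanh applies the product and chain rules to the tanh approximation; spelled out for reference (again a derivation, not from the patch), with 0.79788456 ≈ \sqrt{2/\pi} and 0.1070322243 ≈ 3 · 0.044715 · \sqrt{2/\pi}:

\[
y = \tfrac{1}{2}\,x\,(1 + \tanh u), \qquad
u = \sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)
\]
\[
\frac{dy}{dx}
  = \tfrac{1}{2}\,(1 + \tanh u)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\sqrt{\tfrac{2}{\pi}}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr)
\]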


+class Gelu(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, A, approximate):
+        logging.debug("GEMS GELU FORWARD")
+        if approximate == "tanh":
+            out = gelu_tanh(A)
+        else:
+            out = gelu_none(A)
+        ctx.save_for_backward(A)
+        ctx.approximate = approximate
+        return out

+    @staticmethod
+    def backward(ctx, out_grad):
+        logging.debug("GEMS GELU BACKWARD")
+        (inp,) = ctx.saved_tensors
+        approximate = ctx.approximate
+        if approximate == "tanh":
+            in_grad = gelu_backward_tanh(inp, out_grad)
+        else:
+            in_grad = gelu_backward_none(inp, out_grad)
+        return in_grad, None


def gelu(A, *, approximate="none"):
-    logging.debug("GEMS GELU")
-    if approximate == "tanh":
-        return gelu_tanh(A)
-    else:
-        return gelu_none(A)
+    return Gelu.apply(A, approximate)
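A minimal usage sketch of the new autograd path (assumes a CUDA device and an installed flag_gems build; the shape, dtype, and tanh variant are arbitrary choices for illustration):

import torch
import flag_gems

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

with flag_gems.use_gems():
    # Routed to Gelu.forward via the AutogradCUDA registration in __init__.py above.
    y = torch.nn.functional.gelu(x, approximate="tanh")

# Autograd calls Gelu.backward, which runs the gelu_backward_tanh kernel.
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 1024])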
tests/test_unary_pointwise_ops.py (12 additions, 4 deletions)
@@ -73,16 +73,24 @@ def test_accuracy_exp(shape, dtype):


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-def test_accuracy_gelu(shape, dtype):
-    inp = torch.randn(shape, dtype=dtype, device="cuda")
+@pytest.mark.parametrize("approximate", ["none", "tanh"])
+def test_accuracy_gelu(shape, dtype, approximate):
+    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    ref_inp = to_reference(inp, True)

-    ref_out = torch.nn.functional.gelu(ref_inp)
+    ref_out = torch.nn.functional.gelu(ref_inp, approximate=approximate)
    with flag_gems.use_gems():
-        res_out = torch.nn.functional.gelu(inp)
+        res_out = torch.nn.functional.gelu(inp, approximate=approximate)

    gems_assert_close(res_out, ref_out, dtype)

+    out_grad = torch.randn_like(inp)
+    ref_grad = to_reference(out_grad, True)

+    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
+    (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
+    gems_assert_close(res_in_grad, ref_in_grad, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
