Exponential added. (FlagOpen#138)
* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.
tongxin authored and Bowen12992 committed Aug 6, 2024
1 parent c8fc4eb commit d5fb5a7
Showing 7 changed files with 162 additions and 2 deletions.
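
For orientation, the op added here is meant to be used through flag_gems' aten registration; a minimal usage sketch based on the accuracy tests in this commit (shape, dtype, and rate are illustrative):

    import torch
    import flag_gems

    x = torch.empty(1024, 1024, dtype=torch.float32, device="cuda")
    with flag_gems.use_gems():       # route aten ops to the Triton implementations
        x.exponential_(lambd=0.5)    # fill x in place with Exponential(0.5) samples
    assert x.min() > 0               # exponential samples are strictly positive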
12 changes: 12 additions & 0 deletions benchmark/test_special_perf.py
@@ -53,6 +53,18 @@ def test_perf_rand_like():
    bench.run()


def test_perf_exponential_():
    bench = Benchmark(
        op_name="exponential_",
        torch_op=torch.Tensor.exponential_,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_embedding():
    def embedding_kwargs(dtype, batch, size):
        input = torch.randint(0, batch, (batch,), device="cuda")
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -25,6 +25,8 @@ dependencies = [
[project.optional-dependencies]
test = [
    "pytest>=7.1.0",
    "numpy>=1.26",
    "scipy>=1.14"
]

[project.urls]
2 changes: 1 addition & 1 deletion src/flag_gems/__init__.py
@@ -31,6 +31,7 @@ def enable(lib=aten_lib):
    lib.impl("eq.Tensor", eq, "CUDA")
    lib.impl("eq.Scalar", eq_scalar, "CUDA")
    lib.impl("exp", exp, "CUDA")
    lib.impl("exponential_", exponential_, "CUDA")
    lib.impl("ge.Tensor", ge, "CUDA")
    lib.impl("ge.Scalar", ge_scalar, "CUDA")
    lib.impl("gelu", gelu, "CUDA")
@@ -49,7 +50,6 @@ def enable(lib=aten_lib):
    lib.impl("rand", rand, "CUDA")
    lib.impl("randn", randn, "CUDA")
    lib.impl("rand_like", rand_like, "CUDA")

    lib.impl("mean", mean, "CUDA")
    lib.impl("mean.dim", mean_dim, "CUDA")
    lib.impl("mm", mm, "CUDA")
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -23,6 +23,7 @@
from .eq import eq, eq_scalar
from .erf import erf
from .exp import exp
from .exponential_ import exponential_
from .flip import flip
from .ge import ge, ge_scalar
from .gelu import gelu
@@ -96,6 +97,7 @@
    "eq",
    "eq_scalar",
    "exp",
    "exponential_",
    "flip",
    "ge",
    "ge_scalar",
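
Since exponential_ is re-exported from flag_gems.ops, it can also be invoked directly without going through the aten registration; a minimal sketch (tensor shape and rate are illustrative):

    import torch
    from flag_gems.ops import exponential_

    x = torch.empty(4096, dtype=torch.float16, device="cuda")
    exponential_(x, lambd=2.0)   # fills x in place and returns it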
2 changes: 1 addition & 1 deletion src/flag_gems/ops/dropout.py
@@ -96,7 +96,7 @@ def dropout_backward_kernel(
    philox_offset,
    BLOCK: tl.constexpr,
):
    UNROLL: tl.constexpr = 4
    UNROLL = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
121 changes: 121 additions & 0 deletions src/flag_gems/ops/exponential_.py
@@ -0,0 +1,121 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float


def heur_block(args):
    if args["N"] <= 512:
        return 512
    else:
        return 1024


def heur_num_warps(args):
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16


@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # Each philox call yields four 32-bit random words per counter.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # fp64 path: stitch two 32-bit words into one 64-bit uniform, two outputs per counter.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Lower-precision path: one 32-bit word per output, four outputs per counter.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    # Concatenate two 32-bit random words into a single 64-bit value.
    hi = hi.to(tl.uint64) << 32
    x = hi | lo.to(tl.uint64)
    return x


@triton.jit
def transform_exponential(u, lambd, eps):
    # Inverse-CDF transform x = -log(u) / lambd, with u near 1 clamped so the
    # result stays strictly positive.
    eps1 = -0.5 * eps
    is_min = u >= 1.0 + eps1
    log = tl.where(is_min, eps1, tl.math.log(u))
    v = -1.0 / lambd * log
    return v


def exponential_(x, lambd: float = 1.0, *, gen=None):
    logging.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    UNROLL = 2 if is_double else 4
    N = x.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # (TODO) Using the Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per-thread offset as in PyTorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_cuda_seed_offset(increment)
    eps = torch.finfo(dtype).eps
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    with torch.cuda.device(device):
        fused_exponential_kernel[grid_fn](
            x_, N, is_double, lambd, eps, philox_seed, philox_offset
        )
    if not inplace:
        x.copy_(x_)
    return x
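
The kernel above maps Philox-generated uniforms to exponential samples through the inverse CDF: if U is uniform on (0, 1), then -log(U) / lambd follows Exponential(lambd). A host-side NumPy/SciPy sketch of that transform (illustrative only; it uses the equivalent -log1p(-u) form and is not part of this commit's code path):

    import numpy as np
    import scipy.stats

    lambd = 0.5
    u = np.random.default_rng(0).random(1_000_000)   # uniforms in [0, 1)
    x = -np.log1p(-u) / lambd                        # inverse-CDF transform
    expo_cdf = lambda t: np.where(t < 0, 0, 1.0 - np.exp(-lambd * t))
    print(scipy.stats.kstest(x, expo_cdf).pvalue)    # large p-value: consistent with Exp(lambd)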
23 changes: 23 additions & 0 deletions tests/test_special_ops.py
@@ -1,6 +1,8 @@
from typing import Optional

import numpy as np
import pytest
import scipy
import torch

import flag_gems
@@ -222,3 +224,24 @@ def test_accuracy_rand_like(shape, dtype):
    res_out = torch.rand_like(x)
    assert (res_out <= 1.0).all()
    assert (res_out >= 0.0).all()


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_exponential_(shape, dtype):
    x = torch.empty(size=shape, dtype=dtype, device="cuda")
    with flag_gems.use_gems():
        x.exponential_()
    assert x.min() > 0


@pytest.mark.parametrize("shape", POINTWISE_SHAPES[:1])
@pytest.mark.parametrize("dtype", (torch.float32,))
@pytest.mark.parametrize("lambd", (0.01, 0.5, 100.0))
def test_accuracy_exponential_pvalue(shape, dtype, lambd):
    x = torch.empty(size=shape, dtype=dtype, device="cuda")
    with flag_gems.use_gems():
        x.exponential_(lambd=lambd)
    expo_cdf = lambda x: np.where(x < 0, 0, 1.0 - np.exp(-lambd * x))
    pvalue = scipy.stats.kstest(x.cpu().numpy().flatten(), expo_cdf).pvalue
    assert pvalue > 0.05
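
Beyond the K-S test, a cheap distributional sanity check is that Exponential(lambd) has mean 1 / lambd; a minimal sketch with assumed tolerances (not part of the added test suite):

    import torch
    import flag_gems

    lambd = 0.5
    x = torch.empty(1024, 1024, dtype=torch.float32, device="cuda")
    with flag_gems.use_gems():
        x.exponential_(lambd=lambd)
    # With ~1M samples, the relative error of the sample mean should be well under 2%.
    torch.testing.assert_close(x.mean(), torch.tensor(1.0 / lambd, device="cuda"), rtol=2e-2, atol=0.0)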
