Exponential added. #138

Merged
5 commits merged on Aug 1, 2024
Changes from all commits
12 changes: 12 additions & 0 deletions benchmark/test_special_perf.py
@@ -53,6 +53,18 @@ def test_perf_rand_like():
bench.run()


def test_perf_exponential_():
bench = Benchmark(
op_name="exponential_",
torch_op=torch.Tensor.exponential_,
arg_func=unary_arg,
dtypes=FLOAT_DTYPES,
batch=POINTWISE_BATCH,
sizes=SIZES,
)
bench.run()


def test_perf_embedding():
def embedding_kwargs(dtype, batch, size):
input = torch.randint(0, batch, (batch,), device="cuda")
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -25,6 +25,8 @@ dependencies = [
[project.optional-dependencies]
test = [
"pytest>=7.1.0",
"numpy>=1.26",
"scipy>=1.14"
]

[project.urls]
2 changes: 1 addition & 1 deletion src/flag_gems/__init__.py
@@ -31,6 +31,7 @@ def enable(lib=aten_lib):
lib.impl("eq.Tensor", eq, "CUDA")
lib.impl("eq.Scalar", eq_scalar, "CUDA")
lib.impl("exp", exp, "CUDA")
lib.impl("exponential_", exponential_, "CUDA")
lib.impl("ge.Tensor", ge, "CUDA")
lib.impl("ge.Scalar", ge_scalar, "CUDA")
lib.impl("gelu", gelu, "CUDA")
@@ -49,7 +50,6 @@ def enable(lib=aten_lib):
lib.impl("rand", rand, "CUDA")
lib.impl("randn", randn, "CUDA")
lib.impl("rand_like", rand_like, "CUDA")

lib.impl("mean", mean, "CUDA")
lib.impl("mean.dim", mean_dim, "CUDA")
lib.impl("mm", mm, "CUDA")
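A brief usage sketch of what this registration enables (a hedged example; it assumes a CUDA device with flag_gems installed, and the tensor shape and lambd value are illustrative):

import torch
import flag_gems

flag_gems.enable()                    # installs the CUDA overrides registered above
x = torch.empty(1024, device="cuda")
x.exponential_(0.5)                   # in-place fill, now dispatched to the Triton kernel from this PR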
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -23,6 +23,7 @@
from .eq import eq, eq_scalar
from .erf import erf
from .exp import exp
from .exponential_ import exponential_
from .flip import flip
from .ge import ge, ge_scalar
from .gelu import gelu
@@ -96,6 +97,7 @@
"eq",
"eq_scalar",
"exp",
"exponential_",
"flip",
"ge",
"ge_scalar",
2 changes: 1 addition & 1 deletion src/flag_gems/ops/dropout.py
@@ -96,7 +96,7 @@ def dropout_backward_kernel(
philox_offset,
BLOCK: tl.constexpr,
):
UNROLL: tl.constexpr = 4
UNROLL = 4
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
121 changes: 121 additions & 0 deletions src/flag_gems/ops/exponential_.py
@@ -0,0 +1,121 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float


def heur_block(args):
if args["N"] <= 512:
return 512
else:
return 1024


def heur_num_warps(args):
if args["N"] <= 512:
return 4
elif args["N"] <= 1024:
return 8
else:
return 16


@triton.heuristics(
{
"BLOCK": heur_block,
"num_warps": heur_num_warps,
}
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
out_ptr,
N,
is_double,
lambd,
eps,
philox_seed,
philox_offset,
BLOCK: tl.constexpr,
):
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
c0 += i4
_O = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
if is_double:
d0 = uint_to_uniform_float(paste_u64(r0, r2))
d1 = uint_to_uniform_float(paste_u64(r1, r3))
y0 = transform_exponential(d0, lambd, eps)
y1 = transform_exponential(d1, lambd, eps)
UNROLL = 2
start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
off_0 = start + tl.arange(0, BLOCK)
off_1 = off_0 + BLOCK
tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
else:
f0 = uint_to_uniform_float(r0)
f1 = uint_to_uniform_float(r1)
f2 = uint_to_uniform_float(r2)
f3 = uint_to_uniform_float(r3)
y0 = transform_exponential(f0, lambd, eps)
y1 = transform_exponential(f1, lambd, eps)
y2 = transform_exponential(f2, lambd, eps)
y3 = transform_exponential(f3, lambd, eps)
UNROLL = 4
start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
off_0 = start + tl.arange(0, BLOCK)
off_1 = off_0 + BLOCK
off_2 = off_1 + BLOCK
off_3 = off_2 + BLOCK
tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
hi = hi.to(tl.uint64) << 32
x = hi | lo.to(tl.uint64)
return x


@triton.jit
def transform_exponential(u, lambd, eps):
eps1 = -0.5 * eps
is_min = u >= 1.0 + eps1
log = tl.where(is_min, eps1, tl.math.log(u))
Collaborator: Is this really needed? What about just using log(u) or log(1-u)?

Contributor Author: This is for enforcing compatibility with PyTorch.
v = -1.0 / lambd * log
return v
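For reference, a minimal NumPy sketch of the inverse-CDF transform implemented above (the reference function name is illustrative, not part of the PR, and it assumes u is uniform on (0, 1]): if U is uniform, then -log(U) / lambd follows an exponential distribution with rate lambd, and clamping the log to -eps/2 near u == 1 keeps the output strictly positive.

import numpy as np

def transform_exponential_ref(u, lambd, eps):
    # Mirror of the Triton transform above. Near u == 1, log(u) rounds to 0 and
    # would yield a sample of exactly 0, so the log is clamped to -eps/2; the
    # smallest possible sample is then eps / (2 * lambd) > 0.
    eps1 = -0.5 * eps
    log_u = np.where(u >= 1.0 + eps1, eps1, np.log(u))
    return -1.0 / lambd * log_u

rng = np.random.default_rng(0)
u = 1.0 - rng.random(8, dtype=np.float32)   # uniform samples in (0, 1]
samples = transform_exponential_ref(u, lambd=0.5, eps=np.finfo(np.float32).eps)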


def exponential_(x, lambd: float = 1.0, *, gen=None):
logging.debug("GEMS EXPONENTIAL_")
dtype = x.dtype
device = x.device
inplace = x.is_contiguous()
Collaborator: Performing an in-place operation on a tensor with internal overlap should raise a RuntimeError.

Collaborator: RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.

Collaborator (@iclementine, Aug 1, 2024): Now it would raise a runtime error when copying data back.

import torch
import flag_gems
flag_gems.enable()
x = torch.ones(2, device="cuda")
x = torch.broadcast_to(x, (3, 2))
x.exponential_()

Contributor Author: I'll fix it.

Contributor Author: PyTorch throws exactly the same error, so we'll keep the current behavior.

assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
is_double = dtype in (torch.float64,)
UNROLL = 2 if is_double else 4
N = x.numel()
grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
# (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
# hence we cannot obtain the per thread offset as in Pytorch.
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_cuda_seed_offset(increment)
eps = torch.finfo(dtype).eps
x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
with torch.cuda.device(device):
fused_exponential_kernel[grid_fn](
x_, N, is_double, lambd, eps, philox_seed, philox_offset
)
if not inplace:
x.copy_(x_)
return x
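To make the launch arithmetic in the wrapper concrete, a small worked example for the non-double path (the numbers are illustrative; BLOCK follows the heur_block heuristic above):

import math

N = 10_000                                      # float32 elements to fill
UNROLL = 4                                      # the non-double path writes 4 values per lane
BLOCK = 1024                                    # heur_block returns 1024 since N > 512
num_programs = math.ceil(N / (BLOCK * UNROLL))  # grid size: 3 programs
increment = math.ceil(N / UNROLL)               # offset increment passed to philox_cuda_seed_offset: 2500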
23 changes: 23 additions & 0 deletions tests/test_special_ops.py
@@ -1,6 +1,8 @@
from typing import Optional

import numpy as np
import pytest
import scipy
import torch

import flag_gems
@@ -222,3 +224,24 @@ def test_accuracy_rand_like(shape, dtype):
res_out = torch.rand_like(x)
assert (res_out <= 1.0).all()
assert (res_out >= 0.0).all()


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_exponential_(shape, dtype):
x = torch.empty(size=shape, dtype=dtype, device="cuda")
with flag_gems.use_gems():
x.exponential_()
assert x.min() > 0


@pytest.mark.parametrize("shape", POINTWISE_SHAPES[:1])
@pytest.mark.parametrize("dtype", (torch.float32,))
@pytest.mark.parametrize("lambd", (0.01, 0.5, 100.0))
def test_accuracy_exponential_pvalue(shape, dtype, lambd):
x = torch.empty(size=shape, dtype=dtype, device="cuda")
with flag_gems.use_gems():
x.exponential_(lambd=lambd)
expo_cdf = lambda x: np.where(x < 0, 0, 1.0 - np.exp(-lambd * x))
pvalue = scipy.stats.kstest(x.cpu().numpy().flatten(), expo_cdf).pvalue
assert pvalue > 0.05