Exponential added. #138
Changes from all commits: 3d1da5d, b8cd3be, e78654b, 07afd22, c513584
@@ -0,0 +1,121 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float


def heur_block(args):
    if args["N"] <= 512:
        return 512
    else:
        return 1024


def heur_num_warps(args):
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16


@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit Philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Give every element handled by this program instance a distinct counter value.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # Philox yields four 32-bit random words per counter.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # float64: pair two 32-bit words per sample, so each thread writes 2 outputs.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # 32-bit and narrower dtypes: one 32-bit word per sample, 4 outputs per thread.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    # Concatenate two 32-bit words into one 64-bit word (hi in the upper 32 bits).
    hi = hi.to(tl.uint64) << 32
    x = hi | lo.to(tl.uint64)
    return x


@triton.jit
def transform_exponential(u, lambd, eps):
    # Inverse-transform sampling: map a uniform u in (0, 1] to -log(u) / lambd.
    # When u rounds to 1.0, substitute -eps / 2 for log(u) so the sample is never exactly 0.
    eps1 = -0.5 * eps
    is_min = u >= 1.0 + eps1
    log = tl.where(is_min, eps1, tl.math.log(u))
    v = -1.0 / lambd * log
    return v
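For reference (not part of the diff): transform_exponential is the standard inverse-transform sampler for the exponential distribution. If u is uniform on (0, 1], then

$$
P\left(-\tfrac{1}{\lambda}\ln u > t\right) = P\left(u < e^{-\lambda t}\right) = e^{-\lambda t}, \qquad t \ge 0,
$$

so -log(u) / lambd follows Exponential(lambd); the eps branch only touches uniform samples that round to 1.0.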
def exponential_(x, lambd: float = 1.0, *, gen=None):
    logging.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    # Write in place only for contiguous tensors; otherwise fill a scratch buffer and copy back.
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    UNROLL = 2 if is_double else 4
    N = x.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # (TODO) Using the Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per-thread offset as in PyTorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_cuda_seed_offset(increment)
    eps = torch.finfo(dtype).eps
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    with torch.cuda.device(device):
        fused_exponential_kernel[grid_fn](
            x_, N, is_double, lambd, eps, philox_seed, philox_offset
        )
    if not inplace:
        x.copy_(x_)
    return x

Review conversation on the line "inplace = x.is_contiguous()":

Comment: Performing an in-place operation on a tensor with internal overlapping should raise a runtime exception.

Reply: Now it would raise a runtime error when copying the data back:

import torch
import flag_gems
flag_gems.enable()
x = torch.ones(2, device="cuda")
x = torch.broadcast_to(x, (3, 2))
x.exponential_()

Reply: I'll fix it.

Reply: PyTorch throws exactly the same error. We'll just keep the current way.
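On the last reply: a minimal sketch (assuming a CUDA device is available; not part of the diff) of the same repro against stock PyTorch, i.e. without flag_gems enabled, which, per that reply, fails with the same error:

import torch

x = torch.ones(2, device="cuda")
x = torch.broadcast_to(x, (3, 2))  # expanded view: several elements share one memory location
x.exponential_()  # expected: RuntimeError about an in-place op on a tensor with internal overlap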
Review conversation on transform_exponential (marked as resolved by tongxin):

Comment: Is this really needed? What about just using log(u) or log(1 - u)?

Reply: This is for enforcing compatibility with PyTorch.
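To make the eps point concrete, here is a minimal CPU sketch (a hypothetical reference re-implementation, not part of the PR) of the same arithmetic as transform_exponential. The branch only triggers when u is within eps / 2 of 1.0; at u == 1.0 a plain log(u) would give a sample of exactly 0, while with the clamp the smallest possible sample is eps / (2 * lambd), which, per the reply above, mirrors PyTorch.

import torch

def transform_exponential_ref(u: torch.Tensor, lambd: float, eps: float) -> torch.Tensor:
    # Element-wise mirror of the Triton transform_exponential above.
    eps1 = -0.5 * eps
    is_min = u >= 1.0 + eps1  # true when u is within eps / 2 of 1.0
    log = torch.where(is_min, torch.full_like(u, eps1), torch.log(u))
    return -1.0 / lambd * log

lambd = 2.0
eps = torch.finfo(torch.float32).eps
u = torch.rand(1_000_000).clamp_min(torch.finfo(torch.float32).tiny)  # keep u strictly positive
x = transform_exponential_ref(u, lambd, eps)
print(x.mean().item())  # roughly 1 / lambd = 0.5
print(x.min().item())   # strictly positive; never exactly 0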