Multinomial (#141)
* WIP: multinomial

* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* Adding multinomial.

* fixed one off error in binary search

* Added multinomial tests without replacement.

* PR comment

* split test_special_ops

* updated with_replacement tests

* add K-S test

* split special perf

* Update to a more reliable without-replacement test

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

1. fix amax, argmax and triu: use int64 indexing when the largest tensor's size_in_bytes exceeds int32's max;
2. change the tiling scheme for argmax to loop over the reduction dimension, instead of using a data-size-dependent tile size

* test for op

* test for op

* Added multinomial perf tests.

* Making libentry thread safe (#136)

* libentry is now lock-protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* resolved conflicts with master.

* fixing multinomial, work in progress.

* Multinomial passes tests.

* Enhance multinomial tests and benchmarks.

* [bugfix] keepdim when num_samples is one

* [bugfix] fix accuracy test

* fix anomalous behavior in fused_renorm_cumsum

* Polish multinomial tests.

* remove garbage files.

* bfloat16 added for multinomial, polish without replacement test.

* Enable two-pass normed cumsum.

* cumsum updated

* normed cumsum complete.

* Fixed multinomial binary search boundary bug

* fix normed_cumsum bugs.

* quick fix dim check.

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
Co-authored-by: StrongSpoon <strongspoon@outlook.com>
6 people authored Sep 2, 2024
1 parent 1c7899d commit 2f191fe
Showing 9 changed files with 468 additions and 1 deletion.
17 changes: 17 additions & 0 deletions benchmark/test_special_perf.py
@@ -98,6 +98,23 @@ def unique_kwargs(dtype, batch, size):
bench.run()


def test_multinomial_with_replacement():
def multinomial_args(dtype, batch, size):
dist = torch.rand((batch, size), dtype=dtype, device="cuda")
n_samples = 10000
return (dist, n_samples, True)

bench = Benchmark(
op_name="multinomial",
torch_op=torch.multinomial,
arg_func=multinomial_args,
dtypes=(torch.float16, torch.float32),
batch=POINTWISE_BATCH,
sizes=SIZES,
)
bench.run()


def test_perf_pad():
def padding_kwargs(dtype, batch, size):
input = torch.randn((batch, size), device="cuda", dtype=dtype)
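For reference, a minimal standalone version of the benchmarked call (shapes chosen for illustration, not the benchmark's actual sweep over POINTWISE_BATCH and SIZES):

import torch

# Draw 10000 samples per row, with replacement, from unnormalized row weights.
dist = torch.rand((4, 1024), dtype=torch.float32, device="cuda")
samples = torch.multinomial(dist, 10000, True)
# samples: shape (4, 10000), dtype int64; entries index columns of dist.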
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -82,6 +82,7 @@ def enable(lib=aten_lib):
lib.impl("mean.dim", mean_dim, "CUDA")
lib.impl("mm", mm, "CUDA")
lib.impl("mul.Tensor", mul, "CUDA")
lib.impl("multinomial", multinomial, "CUDA")
lib.impl("mv", mv, "CUDA")
lib.impl("ne.Tensor", ne, "CUDA")
lib.impl("ne.Scalar", ne_scalar, "CUDA")
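Once registered, enabling flag_gems routes the CUDA multinomial overload to the Triton implementation. A minimal sketch of the intended usage:

import torch
import flag_gems

flag_gems.enable()  # patches the aten library; multinomial now dispatches to the Triton kernel on CUDA
probs = torch.rand(8, 256, device="cuda")
idx = torch.multinomial(probs, 4, replacement=True)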
5 changes: 4 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -16,7 +16,7 @@
from .clamp import clamp, clamp_tensor
from .cos import cos
from .cross_entropy_loss import cross_entropy_loss
from .cumsum import cumsum
from .cumsum import cumsum, normed_cumsum
from .div import div_mode, floor_divide, true_divide
from .dropout import native_dropout
from .embedding import embedding
@@ -48,6 +48,7 @@
from .minimum import minimum
from .mm import mm
from .mul import mul
from .multinomial import multinomial
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
@@ -115,6 +116,7 @@
"cos",
"pad",
"cumsum",
"normed_cumsum",
"true_divide",
"div_mode",
"floor_divide",
@@ -153,6 +155,7 @@
"mean_dim",
"mm",
"mul",
"multinomial",
"maximum",
"minimum",
"rand",
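The new exports can also be called directly; a small sketch, assuming the op mirrors torch.multinomial's signature as the aten registration implies:

import torch
from flag_gems.ops import multinomial, normed_cumsum

x = torch.rand(2, 8, device="cuda")
cdf = normed_cumsum(x, dim=-1)  # running sum along dim, divided by the row total
idx = multinomial(x, 3, True)   # 3 samples per row, with replacement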
257 changes: 257 additions & 0 deletions src/flag_gems/ops/cumsum.py
@@ -76,3 +76,260 @@ def cumsum(inp, dim=1, *, dtype=None):
with torch.cuda.device(inp.device):
cumsum_kernel[grid](inp, out, M, N, K)
return out


@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
row_start = tl.program_id(0) * K
row_off = tl.arange(0, BLOCK)
x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
if x.dtype.is_fp16():
x = x.to(tl.float32)
y_sum = tl.sum(x, 0)
y = tl.cumsum(x, 0)
y = y / y_sum
tl.store(out + row_start + row_off, y, mask=row_off < K)
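# Example: for one row x = [1, 1, 2], tl.cumsum gives [1, 2, 4] and tl.sum
# gives 4, so the stored row is the normalized inclusive scan
# [0.25, 0.5, 1.0], i.e. the CDF that multinomial's binary search consumes.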


@libentry()
@triton.jit(
do_not_specialize=[
"r",
"t",
"R",
"K",
"r_stride",
"out_r_stride",
]
)
def block_cumsum_kernel(
inp,
out,
sums,
r,
t,
R,
K,
r_stride,
k_stride,
out_r_stride,
out_k_stride,
OUTPUT_SUMS: tl.constexpr,
NORMALIZE: tl.constexpr,
HAS_OUT_LAYOUT: tl.constexpr,
TILE: tl.constexpr,
):
# One CTA processes a (r, t*tile) chunk
# rows = [ grid.y, grid.y + r )
# cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
gridx = tl.program_id(0).to(tl.int64)
gridy = tl.program_id(1).to(tl.int64)
n_chunks = tl.num_programs(0)

for row in range(gridy * r, min((gridy + 1) * r, R)):
curr_cumsum = tl.zeros((1,), tl.float32)
row_offset = row * r_stride
cols = gridx * t * TILE + tl.arange(0, TILE)
for ti in range(0, t):
cols_offset = cols * k_stride
x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
if x.dtype.is_fp16() | x.dtype.is_bf16():
x = x.to(tl.float32)
tile_sum = tl.sum(x, 0)[None]
tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
curr_cumsum += tile_sum
if HAS_OUT_LAYOUT:
cols_offset = cols * out_k_stride
row_offset = row * out_r_stride
tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
if OUTPUT_SUMS:
tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
cols += TILE
if NORMALIZE:
cols = gridx * t * TILE + tl.arange(0, TILE)
for _ in range(0, t):
cols_offset = cols * k_stride
if HAS_OUT_LAYOUT:
cols_offset = cols * out_k_stride
row_offset = row * out_r_stride
x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
if x.dtype.is_fp16() | x.dtype.is_bf16():
x = x.to(tl.float32)
x = x / curr_cumsum
tl.store(out + row_offset + cols_offset, x, mask=cols < K)
cols += TILE
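# Worked example: one row x = [1, 2, 3, 4] split into two single-tile chunks
# (TILE=2). Chunk 0 stores its local scan [1, 3] (total 3); chunk 1 stores
# [3, 7] (total 7). With OUTPUT_SUMS=True the totals [3, 7] land in `sums`;
# a second pass scans them to [3, 10], and block_update_kernel adds the
# exclusive prefix (0 and 3) and divides by the row total 10, yielding the
# normalized cumsum [0.1, 0.3, 0.6, 1.0].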


@libentry()
@triton.jit(
do_not_specialize=[
"r",
"t",
"R",
"K",
"r_stride",
"out_r_stride",
]
)
def block_update_kernel(
inp,
base,
rscale_ptr,
out,
r,
t,
R,
K,
r_stride,
k_stride,
out_r_stride,
out_k_stride,
rscale_stride,
HAS_OUT_LAYOUT: tl.constexpr,
TILE: tl.constexpr,
):
# One CTA processes a (r, t*tile) chunk
# rows = [ grid.y, grid.y + r )
# cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
gridx = tl.program_id(0).to(tl.int64)
gridy = tl.program_id(1).to(tl.int64)
n_gridx = tl.num_programs(1)

base += gridy * n_gridx + gridx
rscale_ptr += gridy * rscale_stride

for row in range(gridy, min(gridy + r, R)):
d = tl.load(base)
rscale = tl.load(rscale_ptr)
base += gridx
rscale_ptr += rscale_stride
row_offset = row * r_stride
cols = gridx * t * TILE + tl.arange(0, TILE)
for _ in range(0, t):
cols_offset = cols * k_stride
x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
x += d
x /= rscale
if HAS_OUT_LAYOUT:
cols_offset = cols * out_k_stride
row_offset = row * out_r_stride
tl.store(out + row_offset + cols_offset, x, mask=cols < K)
cols += TILE


GRID_Y_LIMIT = 65535


def normed_cumsum(inp, dim=-1):
logging.debug("GEMS NORMED_CUMSUM")
assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
dim = dim % inp.ndim
N = inp.numel()
K = inp.size(dim)
# inp = inp.contiguous()
# The first and last dims are easier to handle; transpose a middle dim to the last.
ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
if is_mid_dim:
inp = inp.transpose(dim, -1).contiguous()
dim = -1
out = torch.empty_like(inp)
with torch.cuda.device(inp.device.index):
# Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
TILE = 2048
# Each row is split into n_chunks chunks, where each chunk is comprised of
# n_tiles tiles. Different chunks are assigned to different CTAs.
n_rows = N // K
n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
k_stride = inp.stride(dim)
r_stride = inp.size(dim) if k_stride == 1 else 1
if n_rows > GRID_Y_LIMIT:
batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
n_batch = triton.cdiv(n_rows, batch)
else:
batch = 1
n_batch = n_rows

grid = (n_chunks, n_batch)
if n_chunks == 1:
block_cumsum_kernel[grid](
inp,
out,
0,
batch,
n_tiles,
n_rows,
K,
r_stride,
k_stride,
r_stride,
k_stride,
OUTPUT_SUMS=False,
NORMALIZE=True,
HAS_OUT_LAYOUT=False,
TILE=TILE,
)
return out

acc_dtype = torch.float32 if inp.dtype != torch.float64 else torch.float64
sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device="cuda")
cumsums = torch.empty_like(sums)
block_cumsum_kernel[grid](
inp,
out,
sums,
batch,
n_tiles,
n_rows,
K,
r_stride,
k_stride,
r_stride,
k_stride,
OUTPUT_SUMS=True,
NORMALIZE=False,
HAS_OUT_LAYOUT=False,
TILE=TILE,
)
# Pass two, scan partial cumsums
block_cumsum_kernel[(1, n_batch)](
sums,
cumsums,
0,
batch,
1,
n_rows,
n_chunks,
n_chunks,
1,
n_chunks,
1,
OUTPUT_SUMS=False,
NORMALIZE=False,
HAS_OUT_LAYOUT=True,
TILE=TILE,
)
rscale = cumsums[..., -1]
block_update_kernel[grid](
out,
cumsums - sums,
rscale,
out,
batch,
n_tiles,
n_rows,
K,
r_stride,
k_stride,
r_stride,
k_stride,
n_chunks,
HAS_OUT_LAYOUT=False,
TILE=TILE,
)
return out
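The two passes compute, per row, chunk-local inclusive scans plus chunk totals, then scan the chunk totals and fold the exclusive chunk prefix back in while dividing by the row total. As a sanity check, the intended result matches this plain-torch reference (a sketch of the semantics, not the kernel path):

import torch

def normed_cumsum_ref(x, dim=-1):
    # Accumulate half-precision inputs in float32, mirroring the kernels.
    acc = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
    c = acc.cumsum(dim)
    return (c / acc.sum(dim, keepdim=True)).to(x.dtype)

# e.g. torch.testing.assert_close(normed_cumsum(x, -1), normed_cumsum_ref(x, -1))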
