def test_multinomial_with_replacement():
    """Benchmark ``torch.multinomial`` when sampling with replacement."""

    def multinomial_args(dtype, batch, size):
        # One row of random weights per distribution; 10000 draws from each
        # row, with replacement enabled.
        weights = torch.rand((batch, size), dtype=dtype, device="cuda")
        return weights, 10000, True

    Benchmark(
        op_name="multinomial",
        torch_op=torch.multinomial,
        arg_func=multinomial_args,
        dtypes=(torch.float16, torch.float32),
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    ).run()
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    # Single-tile variant: each program scans one row of K elements stored
    # back to back and divides the running sum by the row total. Only one
    # BLOCK-wide tile is loaded, so elements at or beyond BLOCK are ignored.
    row_start = tl.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)


@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # One CTA processes a (r, t * TILE) chunk:
    #   rows = [ grid.y * r, (grid.y + 1) * r )   clipped to R
    #   cols = [ grid.x * t * TILE, (grid.x + 1) * t * TILE )   masked to K
    #
    # OUTPUT_SUMS    - also store the chunk total to sums[row, grid.x]
    # NORMALIZE      - divide the scanned chunk by its own total afterwards
    #                  (only a full-row normalisation when there is one chunk)
    # HAS_OUT_LAYOUT - address `out` with out_r_stride / out_k_stride instead
    #                  of the input strides
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_chunks = tl.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running total of the tiles already scanned in this chunk.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        # Keep the output offsets in their own variables so that the input
        # offset is never clobbered between tiles.
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = row_offset
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols_offset
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            tl.store(
                out + out_row_offset + out_cols_offset, tile_cumsum, mask=cols < K
            )
            if OUTPUT_SUMS:
                # Rewritten after every tile; the value left behind is the
                # total of the whole chunk.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                if HAS_OUT_LAYOUT:
                    out_cols_offset = cols * out_k_stride
                else:
                    out_cols_offset = cols * k_stride
                x = tl.load(
                    out + out_row_offset + out_cols_offset, mask=cols < K, other=0
                )
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + out_row_offset + out_cols_offset, x, mask=cols < K)
                cols += TILE


@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # One CTA processes a (r, t * TILE) chunk:
    #   rows = [ grid.y * r, (grid.y + 1) * r )   clipped to R
    #   cols = [ grid.x * t * TILE, (grid.x + 1) * t * TILE )   masked to K
    #
    # For every element of the chunk: out = (inp + base[row, chunk]) / rscale[row]
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    # `base` holds one entry per (row, chunk), so its row pitch is the number
    # of chunks, i.e. the grid extent along x (not along y).
    n_chunks = tl.num_programs(0)

    first_row = gridy * r
    base += first_row * n_chunks + gridx
    rscale_ptr += first_row * rscale_stride

    for row in range(first_row, min(first_row + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        # Step to the same chunk of the next row.
        base += n_chunks
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = row_offset
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols_offset
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            tl.store(out + out_row_offset + out_cols_offset, x, mask=cols < K)
            cols += TILE


# Upper bound used for the y extent of a launch grid; larger row counts are
# folded into batches of rows per CTA.
GRID_Y_LIMIT = 65535


def normed_cumsum(inp, dim=-1):
    """Return the cumulative sum of ``inp`` along ``dim`` divided by its total.

    Every 1-d slice along ``dim`` is scanned and then scaled by the sum of that
    slice, so the last element of each slice is the slice total divided by
    itself.

    Args:
        inp: CUDA tensor of dtype float16, bfloat16, float32 or float64.
        dim: Dimension to scan along; negative values count from the end.

    Returns:
        A tensor with the same shape and dtype as ``inp``. When ``dim`` is
        neither the outermost nor the innermost dimension in memory the
        result is a transposed view of a freshly allocated buffer.

    Raises:
        AssertionError: If ``inp`` has an unsupported dtype.
    """
    logging.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    if N == 0:
        # Nothing to scan; also keeps the divisions below away from zero.
        return torch.empty_like(inp)
    K = inp.size(dim)
    # The kernels address a slice with one (row stride, element stride) pair,
    # which only works for the outermost or innermost dim in memory. A middle
    # dim is moved to the end and the result is transposed back on return.
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    scan_dim = dim
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        scan_dim = -1
    out = torch.empty_like(inp)
    device = inp.device
    with torch.cuda.device(device):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each CTA.
        num_sms = torch.cuda.get_device_properties(device).multi_processor_count
        TILE = 2048
        # Each row is split into n_chunks chunks where each chunk is comprised
        # of n_tiles tiles. Different chunks are assigned to different CTAs.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(scan_dim)
        r_stride = inp.size(scan_dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # A single CTA sees the whole row, so it can scan and normalise
            # in one launch.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
        else:
            # Always bind the accumulator dtype; float64 inputs previously
            # left it undefined.
            # NOTE(review): block_cumsum_kernel seeds its running total as
            # float32 -- confirm the float64 path compiles and is precise
            # enough.
            if inp.dtype == torch.float64:
                acc_dtype = torch.float64
            else:
                acc_dtype = torch.float32
            sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=device)
            cumsums = torch.empty_like(sums)
            block_cumsum_kernel[grid](
                inp,
                out,
                sums,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=True,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            # Pass two, scan the per-chunk totals of every row.
            block_cumsum_kernel[(1, n_batch)](
                sums,
                cumsums,
                0,
                batch,
                1,
                n_rows,
                n_chunks,
                n_chunks,
                1,
                n_chunks,
                1,
                OUTPUT_SUMS=False,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=True,
                TILE=TILE,
            )
            # Pass three, add the total of all preceding chunks
            # (cumsums - sums) to each chunk and divide by the row total
            # (last column of cumsums).
            rscale = cumsums[..., -1]
            block_update_kernel[grid](
                out,
                cumsums - sums,
                rscale,
                out,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                n_chunks,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
    if is_mid_dim:
        # Undo the transposition so the result matches the caller's shape.
        out = out.transpose(dim, -1)
    return out
@libentry()
@triton.heuristics(
    {
        "NBLOCK": lambda args: 128,
        "num_warps": lambda args: 4,
    }
)
@triton.jit(do_not_specialize=["K", "N", "philox_seed", "philox_offset"])
def multinomial_with_replacement(
    cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr
):
    # The computation is arranged in a 2d grid of blocks, each producing
    # a batch of samples for a particular distribution.
    #            <------------------- grid.x --------------------->
    #           |   dist0.batch0 | dist0.batch1 | dist0.batch2 ...
    #   grid.y  |   dist1.batch0 | dist1.batch1 | dist1.batch2 ...
    #           |   dist2.batch0 | dist2.batch1 | dist2.batch2 ...
    y_off = tl.program_id(1) * N
    n = tl.program_id(0) * NBLOCK + tl.arange(0, NBLOCK)
    rv, _, _, _ = uniform(philox_seed, philox_offset, y_off + n)

    # Do a binary search for each random number on the cumulative
    # probabilities. Each random number selects the leftmost index whose
    # cumulative probability is greater than or equal to it. That gives a
    # wrong result when the first probability is zero and the random number
    # is zero as well: index 0 would be picked although it must never be
    # selected. To avoid this, the random variable is perturbed by a small
    # amount and capped just below one.
    rv += 0.0001
    rv = tl.where(rv > 0.9999, 0.9999, rv)

    cdf_ptr += tl.program_id(1) * K
    start = tl.zeros((NBLOCK,), dtype=tl.int32)
    end = tl.zeros((NBLOCK,), dtype=tl.int32) + K - 1
    steps = tl.math.log2(K.to(tl.float32)).to(tl.int32) + 1
    for _ in range(steps):
        mid = start + (end - start) // 2
        x = tl.load(cdf_ptr + mid, mask=n < N)
        start = tl.where(x < rv, mid + 1, start)
        end = tl.where(x < rv, end, mid)

    # Returns the last index in case of an overflow
    start = tl.where(start >= K, K - 1, start)

    tl.store(out_ptr + y_off + n, start, mask=n < N)


def multinomial(prob, n_samples, with_replacement=False, *, gen=None):
    """Draw category indices from the distributions given by ``prob``.

    Args:
        prob: 1-d or 2-d floating point tensor of non-negative weights; the
            last dimension indexes the categories. Weights need not be
            normalised.
        n_samples: Number of indices to draw per distribution; must be
            positive.
        with_replacement: Whether a category may be drawn more than once.
        gen: Accepted for signature compatibility; this implementation never
            reads it.

    Returns:
        An int64 tensor of category indices with ``n_samples`` entries along
        the last dimension.

    Raises:
        AssertionError: If the dtype, rank, number of categories or
            ``n_samples`` is not supported.
    """
    logging.debug("GEMS MULTINOMIAL")
    assert prob.dtype in (torch.float16, torch.float32, torch.bfloat16, torch.float64)
    assert 0 < prob.dim() <= 2, "prob_dist must be 1 or 2 dim"
    n_categories = prob.size(-1)
    assert n_categories <= (1 << 24), "number of categories cannot exceed 2^24"
    # A non-positive sample count would otherwise reach a zero-sized launch
    # grid or an empty top-k.
    assert n_samples > 0, "cannot sample n_samples <= 0 samples"
    assert (
        with_replacement or n_samples <= n_categories
    ), "cannot sample n_samples > prob.size(-1) samples without replacement."

    # Sampling without replacement (also used for any single draw)
    if (not with_replacement) or n_samples == 1:
        # Without replacement, sampling is approximated by selecting the
        # top k indices over the probabilities after an exponential
        # perturbation:
        # s = argmax( p / q ) where q ~ Exp(1)
        q = torch.empty_like(prob).exponential_(1.0)
        s = torch.div(prob, q, out=q)
        if n_samples == 1:
            return torch.argmax(s, dim=-1, keepdim=True).to(torch.int64)
        _, indices = torch.topk(s, n_samples, dim=-1)
        return indices.to(torch.int64)

    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import with
    # flag_gems.ops -- confirm.
    from flag_gems.ops import normed_cumsum

    # The kernel indexes the CDF as dense row-major storage, so make sure the
    # buffer really is contiguous (a no-op when it already is).
    cum_prob = normed_cumsum(prob, dim=-1).contiguous()

    if cum_prob.dim() == 1:
        n_dist = 1
        out = torch.empty((n_samples,), device=prob.device, dtype=torch.int64)
    else:
        n_dist = cum_prob.size(0)
        out = torch.empty((n_dist, n_samples), device=prob.device, dtype=torch.int64)
    # The CTA level parallelism is framed in a 2d grid of blocks with grid.y
    # indexing into distributions and grid.x output sample batches
    increment = n_dist * n_samples
    philox_seed, philox_offset = philox_cuda_seed_offset(increment)
    grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist)
    multinomial_with_replacement[grid](
        cum_prob, out, n_categories, n_samples, philox_seed, philox_offset
    )
    return out
@pytest.mark.parametrize("shape", [(1024, 10)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("n_samples", [2048])
def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples):
    """Check sampled category frequencies against the input weights."""
    # `import scipy` alone does not guarantee that the `stats` submodule has
    # been loaded, so import it explicitly.
    from scipy import stats

    n_categories = shape[-1]
    # First use multinomial to generate a series of indices, then
    # use the index counts as the input probabilities (scaled)
    rand_indices = torch.multinomial(torch.rand(shape), n_samples, True).to("cuda")
    # Fix the table width: without num_classes, one_hot infers it from the
    # largest index present, so the two tables could differ in shape.
    inp_counts = torch.nn.functional.one_hot(
        rand_indices, num_classes=n_categories
    ).sum(1)
    with flag_gems.use_gems():
        out_indices = torch.multinomial(inp_counts.to(dtype=dtype), n_samples, True)
    out_counts = torch.nn.functional.one_hot(
        out_indices, num_classes=n_categories
    ).sum(1)
    # Do a simple Chi-square test
    assert torch.equal(inp_counts.sum(-1), out_counts.sum(-1))
    _, pvalue = stats.chisquare(out_counts.tolist(), inp_counts.tolist(), axis=-1)
    # Tolerate a small fraction of rows rejected at the 5% level.
    assert np.sum(pvalue < 0.05) / len(pvalue) < 0.1
torch.multinomial(dist, n, False) + # Verifies uniqueness + idx_cnt = torch.nn.functional.one_hot(out).sum(1) + assert torch.all(idx_cnt <= 1) + + @pytest.mark.parametrize("shape", [[1024, 1024], [64, 64, 64, 64]]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("pad_mode", ["constant", "reflect", "replicate", "circular"])