Skip to content

Commit

Permalink
[FRONTEND] use unsigned integers to simplify RNG (triton-lang#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
madeleineth authored Jan 6, 2022
1 parent 001fb75 commit 120cda0
Showing 1 changed file with 28 additions and 33 deletions.
61 changes: 28 additions & 33 deletions python/triton/language/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,16 @@
from . import core as tl


# Notes
# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations.
# 2. multiply_low_high is currently inefficient.
# 3. Even though philox sampling technically outputs ints, in many places we pretend they are actually uints, e.g. in uint32_to_uniform_float.

PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox

# -------------------
# randint
# -------------------

@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)

@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
Expand All @@ -40,12 +32,13 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
k1 = k1 + PHILOX_KEY_B
return c0, c1, c2, c3


@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times.
Expand All @@ -55,23 +48,23 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret


@triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`.
This is the maximally efficient entry point
Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`.
This is the maximally efficient entry point
to Triton's Philox pseudo-random number generator.
:param seed: The seed for generating random numbers.
:param offset: The offsets to generate random numbers for.
"""
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
seed = seed + 0
seed = hacky_to_uint64(seed) # uint will solve this
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
seed_lo = (seed & 0xffffffff).to(tl.int32)
z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.uint32)
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)


Expand All @@ -82,18 +75,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@triton.jit
def uint32_to_uniform_float(x):
"""
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
"""
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
x = tl.where(x < 0, -x - 1, x)
return x * max
two_to_the_minus_32 = 2.328306e-10
return x * two_to_the_minus_32


@triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`U(0, 1)`
:param seed: The seed for generating random numbers.
Expand All @@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
source = randint(seed, offset, n_rounds)
return uint32_to_uniform_float(source)


@triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Expand All @@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# randn
# -------------------


@triton.jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
Expand All @@ -130,10 +123,11 @@ def pair_uniform_to_normal(u1, u2):
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)


@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers.
Expand All @@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
n1, _ = pair_uniform_to_normal(u1, u2)
return n1


@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Expand Down

0 comments on commit 120cda0

Please sign in to comment.