Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666466396
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Aug 22, 2024
1 parent a485298 commit 2c2dfd5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
68 changes: 44 additions & 24 deletions tensorflow_probability/python/internal/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ============================================================================
"""Random samplers."""

import contextlib
import collections
import hashlib
import warnings

Expand Down Expand Up @@ -49,17 +49,22 @@

SEED_DTYPE = np.uint32 if JAX_MODE else np.int32

_old_salt = False

_OldSaltSeed = collections.namedtuple('_OldSaltSeed', ['seed'])

@contextlib.contextmanager
def enable_old_salt(enable):
global _old_salt
try:
_old_salt = enable
yield
finally:
_old_salt = False

def enable_old_salt_for_seed(seed, enable):
if enable:
return _OldSaltSeed(seed)
else:
return seed


def _get_seed_and_old_salt(seed):
if isinstance(seed, _OldSaltSeed):
return seed.seed, True
else:
return seed, False


def zeros_seed():
Expand Down Expand Up @@ -116,6 +121,7 @@ def sanitize_seed(seed, salt=None, name=None):
seed.
"""
seed, old_salt = _get_seed_and_old_salt(seed)
if callable(seed): # e.g. SeedStream.
seed = seed()
if salt is not None and not isinstance(salt, str):
Expand Down Expand Up @@ -154,18 +160,18 @@ def sanitize_seed(seed, salt=None, name=None):

if salt is not None:
salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16)
if not _old_salt:
if not old_salt:
salt = salt % (2**31 - 1)
seed = fold_in(seed, salt)

if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
# Typed keys are returned as is, otherwise wrap them.
if jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key):
return seed
else:
return jax.random.wrap_key_data(seed)
return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed')
if not jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key):
seed = jax.random.wrap_key_data(seed)
else:
seed = tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed')
return enable_old_salt_for_seed(seed, old_salt)


def get_integer_seed(seed):
Expand All @@ -181,6 +187,7 @@ def get_integer_seed(seed):
if isinstance(seed, six.integer_types):
return seed % (2**31)
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
# maxval is exclusive, so technically this doesn't generate all possible
# non-negative integers, but it's good enough for our purposes.
integer_seed = tf.random.stateless_uniform(
Expand All @@ -194,17 +201,19 @@ def get_integer_seed(seed):

def fold_in(seed, salt):
"""Folds salt into seed to form a new seed."""
seed, old_salt = _get_seed_and_old_salt(seed)
if JAX_MODE:
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
import jax.numpy as jnp # pylint: disable=g-import-not-at-top
return jaxrand.fold_in(
seed = jaxrand.fold_in(
seed, jnp.asarray(salt & np.uint32(2**32 - 1), dtype=SEED_DTYPE))
if isinstance(salt, (six.integer_types)):
seed = tf.bitwise.bitwise_xor(
seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
else:
seed = tf.random.experimental.stateless_fold_in(seed, salt)
return seed
if isinstance(salt, (six.integer_types)):
seed = tf.bitwise.bitwise_xor(
seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))
else:
seed = tf.random.experimental.stateless_fold_in(seed, salt)
return enable_old_salt_for_seed(seed, old_salt)


def split_seed(seed, n=2, salt=None, name=None):
Expand Down Expand Up @@ -235,24 +244,29 @@ def split_seed(seed, n=2, salt=None, name=None):
'`n` must be a python `int` or an int Tensor, got {}'.format(repr(n)))
with tf.name_scope(name or 'split_seed'):
seed = sanitize_seed(seed, salt=salt)
seed, old_salt = _get_seed_and_old_salt(seed)
if JAX_MODE:
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
return jaxrand.split(seed, int(n))
seeds = tf.random.stateless_uniform(
[n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE)
if isinstance(n, six.integer_types):
seeds = tf.unstack(seeds)
seeds = [enable_old_salt_for_seed(seed, old_salt) for seed in seeds]
else:
seeds = enable_old_salt_for_seed(seeds, old_salt)
return seeds


def clone_seed(seed):
"""Clones a seed so it can be reused without causing a JAX KeyReuseError."""
seed, old_salt = _get_seed_and_old_salt(seed)
if JAX_MODE:
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
if hasattr(jaxrand, 'clone'):
# JAX v0.4.26+
return jaxrand.clone(seed)
return seed
seed = jaxrand.clone(seed)
return enable_old_salt_for_seed(seed, old_salt)


def categorical(
Expand All @@ -264,6 +278,7 @@ def categorical(
"""As `tf.random.categorical`, but handling stateful/stateless `seed`s."""
with tf.name_scope(name or 'categorical'):
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
return tf.random.stateless_categorical(
logits=logits, num_samples=num_samples, seed=seed, dtype=dtype)

Expand All @@ -278,6 +293,7 @@ def gamma(
"""As `tf.random.gamma`, but handling stateful/stateless `seed`s."""
with tf.name_scope(name or 'gamma'):
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
alpha = tf.convert_to_tensor(alpha, dtype=dtype)
beta = None if beta is None else tf.convert_to_tensor(beta, dtype=dtype)
params_shape = ps.shape(alpha)
Expand Down Expand Up @@ -306,6 +322,7 @@ def normal(
shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype)

seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
return tf.random.stateless_normal(
shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype)

Expand All @@ -319,6 +336,7 @@ def poisson(
"""As `tf.random.poisson`, but handling stateful/stateless `seed`s."""
with tf.name_scope(name or 'poisson'):
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
lam_shape = ps.shape(lam)
sample_shape = ps.concat([shape, lam_shape], axis=0)
return tf.random.stateless_poisson(
Expand All @@ -332,6 +350,7 @@ def shuffle(
"""As `tf.random.shuffle`, but handling stateful/stateless `seed`s."""
with tf.name_scope(name or 'shuffle'):
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
sortkey = tf.random.stateless_uniform(shape=[ps.shape(value)[0]], seed=seed)
return tf.gather(value, tf.argsort(sortkey))

Expand All @@ -346,5 +365,6 @@ def uniform(
"""As `tf.random.uniform`, but handling stateful/stateless `seed`s."""
with tf.name_scope(name or 'uniform'):
seed = sanitize_seed(seed)
seed, _ = _get_seed_and_old_salt(seed)
return tf.random.stateless_uniform(
shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype)
10 changes: 5 additions & 5 deletions tensorflow_probability/python/internal/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def setUp(self):
def test_old_salt(self):
if not tf1.control_flow_v2_enabled():
self.skipTest('TF2 only.')
with samplers.enable_old_salt(True):
seed = samplers.sanitize_seed(0, salt='nacl')
seed = samplers.sanitize_seed(seed, salt='kcl')
val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed)
self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val))
seed = samplers.sanitize_seed(0, salt='nacl')
seed = samplers.enable_old_salt_for_seed(seed, True)
seed = samplers.sanitize_seed(seed, salt='kcl')
val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed)
self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val))

def test_new_style_jax_keys(self):
if not JAX_MODE:
Expand Down

0 comments on commit 2c2dfd5

Please sign in to comment.