From 2c2dfd54ae15f9c5c094c74e2a023b9696233f65 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 22 Aug 2024 13:23:06 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 666466396 --- .../python/internal/samplers.py | 68 ++++++++++++------- .../python/internal/samplers_test.py | 10 +-- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index f8afd26046..894291fdc9 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -14,7 +14,7 @@ # ============================================================================ """Random samplers.""" -import contextlib +import collections import hashlib import warnings @@ -49,17 +49,22 @@ SEED_DTYPE = np.uint32 if JAX_MODE else np.int32 -_old_salt = False +_OldSaltSeed = collections.namedtuple('_OldSaltSeed', ['seed']) -@contextlib.contextmanager -def enable_old_salt(enable): - global _old_salt - try: - _old_salt = enable - yield - finally: - _old_salt = False + +def enable_old_salt_for_seed(seed, enable): + if enable: + return _OldSaltSeed(seed) + else: + return seed + + +def _get_seed_and_old_salt(seed): + if isinstance(seed, _OldSaltSeed): + return seed.seed, True + else: + return seed, False def zeros_seed(): @@ -116,6 +121,7 @@ def sanitize_seed(seed, salt=None, name=None): seed. """ + seed, old_salt = _get_seed_and_old_salt(seed) if callable(seed): # e.g. SeedStream. seed = seed() if salt is not None and not isinstance(salt, str): @@ -154,18 +160,18 @@ def sanitize_seed(seed, salt=None, name=None): if salt is not None: salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) - if not _old_salt: + if not old_salt: salt = salt % (2**31 - 1) seed = fold_in(seed, salt) if JAX_MODE: import jax # pylint: disable=g-import-not-at-top # Typed keys are returned as is, otherwise wrap them. - if jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key): - return seed - else: - return jax.random.wrap_key_data(seed) - return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed') + if not jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key): + seed = jax.random.wrap_key_data(seed) + else: + seed = tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed') + return enable_old_salt_for_seed(seed, old_salt) def get_integer_seed(seed): @@ -181,6 +187,7 @@ def get_integer_seed(seed): if isinstance(seed, six.integer_types): return seed % (2**31) seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) # maxval is exclusive, so technically this doesn't generate all possible # non-negative integers, but it's good enough for our purposes. integer_seed = tf.random.stateless_uniform( @@ -194,17 +201,19 @@ def get_integer_seed(seed): def fold_in(seed, salt): """Folds salt into seed to form a new seed.""" + seed, old_salt = _get_seed_and_old_salt(seed) if JAX_MODE: from jax import random as jaxrand # pylint: disable=g-import-not-at-top import jax.numpy as jnp # pylint: disable=g-import-not-at-top - return jaxrand.fold_in( + seed = jaxrand.fold_in( seed, jnp.asarray(salt & np.uint32(2**32 - 1), dtype=SEED_DTYPE)) - if isinstance(salt, (six.integer_types)): - seed = tf.bitwise.bitwise_xor( - seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) else: - seed = tf.random.experimental.stateless_fold_in(seed, salt) - return seed + if isinstance(salt, (six.integer_types)): + seed = tf.bitwise.bitwise_xor( + seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) + else: + seed = tf.random.experimental.stateless_fold_in(seed, salt) + return enable_old_salt_for_seed(seed, old_salt) def split_seed(seed, n=2, salt=None, name=None): @@ -235,6 +244,7 @@ def split_seed(seed, n=2, salt=None, name=None): '`n` must be a python `int` or an int Tensor, got {}'.format(repr(n))) with tf.name_scope(name or 'split_seed'): seed = sanitize_seed(seed, salt=salt) + seed, old_salt = _get_seed_and_old_salt(seed) if JAX_MODE: from jax import random as jaxrand # pylint: disable=g-import-not-at-top return jaxrand.split(seed, int(n)) @@ -242,17 +252,21 @@ def split_seed(seed, n=2, salt=None, name=None): [n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE) if isinstance(n, six.integer_types): seeds = tf.unstack(seeds) + seeds = [enable_old_salt_for_seed(seed, old_salt) for seed in seeds] + else: + seeds = enable_old_salt_for_seed(seeds, old_salt) return seeds def clone_seed(seed): """Clones a seed so it can be reused without causing a JAX KeyReuseError.""" + seed, old_salt = _get_seed_and_old_salt(seed) if JAX_MODE: from jax import random as jaxrand # pylint: disable=g-import-not-at-top if hasattr(jaxrand, 'clone'): # JAX v0.4.26+ - return jaxrand.clone(seed) - return seed + seed = jaxrand.clone(seed) + return enable_old_salt_for_seed(seed, old_salt) def categorical( @@ -264,6 +278,7 @@ def categorical( """As `tf.random.categorical`, but handling stateful/stateless `seed`s.""" with tf.name_scope(name or 'categorical'): seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) return tf.random.stateless_categorical( logits=logits, num_samples=num_samples, seed=seed, dtype=dtype) @@ -278,6 +293,7 @@ def gamma( """As `tf.random.gamma`, but handling stateful/stateless `seed`s.""" with tf.name_scope(name or 'gamma'): seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) alpha = tf.convert_to_tensor(alpha, dtype=dtype) beta = None if beta is None else tf.convert_to_tensor(beta, dtype=dtype) params_shape = ps.shape(alpha) @@ -306,6 +322,7 @@ def normal( shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype) seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) return tf.random.stateless_normal( shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype) @@ -319,6 +336,7 @@ def poisson( """As `tf.random.poisson`, but handling stateful/stateless `seed`s.""" with tf.name_scope(name or 'poisson'): seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) lam_shape = ps.shape(lam) sample_shape = ps.concat([shape, lam_shape], axis=0) return tf.random.stateless_poisson( @@ -332,6 +350,7 @@ def shuffle( """As `tf.random.shuffle`, but handling stateful/stateless `seed`s.""" with tf.name_scope(name or 'shuffle'): seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) sortkey = tf.random.stateless_uniform(shape=[ps.shape(value)[0]], seed=seed) return tf.gather(value, tf.argsort(sortkey)) @@ -346,5 +365,6 @@ def uniform( """As `tf.random.uniform`, but handling stateful/stateless `seed`s.""" with tf.name_scope(name or 'uniform'): seed = sanitize_seed(seed) + seed, _ = _get_seed_and_old_salt(seed) return tf.random.stateless_uniform( shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype) diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 05d380d218..9c2615522b 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -45,11 +45,11 @@ def setUp(self): def test_old_salt(self): if not tf1.control_flow_v2_enabled(): self.skipTest('TF2 only.') - with samplers.enable_old_salt(True): - seed = samplers.sanitize_seed(0, salt='nacl') - seed = samplers.sanitize_seed(seed, salt='kcl') - val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed) - self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val)) + seed = samplers.sanitize_seed(0, salt='nacl') + seed = samplers.enable_old_salt_for_seed(seed, True) + seed = samplers.sanitize_seed(seed, salt='kcl') + val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed) + self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val)) def test_new_style_jax_keys(self): if not JAX_MODE: