Skip to content

Commit

Permalink
Maybe cast seeds before use in span corruption preprocessors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600576320
  • Loading branch information
gauravmishra authored and t5-copybara committed Jan 23, 2024
1 parent f1f16c0 commit c7cdaa5
Show file tree
Hide file tree
Showing 43 changed files with 99 additions and 43 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/dataset_providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/dataset_providers_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/glue_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/mixtures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/postprocessors_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
18 changes: 17 additions & 1 deletion t5/data/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2050,6 +2050,20 @@ def prefix_lm(dataset, sequence_length, output_features,
return ds


def maybe_cast_seed(seed):
"""Cast seed to int64 if needed."""
def _maybe_cast_seed(s):
if s.dtype == tf.int64 or s.dtype == tf.int32:
return s
return tf.cast(s, tf.int64)

# seed is either a Tensor or a pair of Tensors.
if isinstance(seed, tf.Tensor):
return _maybe_cast_seed(seed)
else:
return _maybe_cast_seed(seed[0]), _maybe_cast_seed(seed[1])


def full_lm(dataset, sequence_length, output_features):
"""Full language modeling objective with EOS only at document boundaries."""
ds = dataset
Expand Down Expand Up @@ -2119,6 +2133,7 @@ def single_example_select_random_chunk(
if max_length is None:
raise ValueError('Must specify max_length or sequence_length.')

seed = maybe_cast_seed(seed)
seeds = tf.unstack(tf.random.experimental.stateless_split(seed))
tokens = features[feature_key]
n_tokens = tf.shape(tokens)[0]
Expand Down Expand Up @@ -2790,6 +2805,7 @@ def single_example_denoise(features: FeatureType,
raise ValueError(
f"passthrough keys cannot contain '{input_feature_key}' or 'targets'")

seed = maybe_cast_seed(seed)
seeds = tf.unstack(tf.random.experimental.stateless_split(seed, 6))
tokens = features['targets']
vocabulary = output_features['targets'].vocabulary
Expand Down
42 changes: 41 additions & 1 deletion t5/data/preprocessors_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,8 @@

from absl.testing import absltest
import gin
import jax
import numpy as np
import seqio
from seqio import test_utils
from t5.data import preprocessors as prep
Expand Down Expand Up @@ -2418,6 +2420,44 @@ def test_span_corruption_passthrough_fail(self):
merge_examples_to_reduce_padding=True,
passthrough_feature_keys=['passthrough'])

def test_cast_seeds_works_with_jax_rngs(self):
jax_rng = jax.random.PRNGKey(0)
cast_rng = prep.maybe_cast_seed(jax_rng)
tf_seed = tf.random.experimental.stateless_split(cast_rng)
expected_seed = tf.constant(
[
[-6856063536328932722, -6061095977527908230],
[6624627857932008774, -9000103707649019216],
],
shape=(2, 2),
dtype=tf.int64,
)
tf.assert_equal(tf_seed, expected_seed)

def test_cast_seeds_works_with_tf_rngs(self):
# This imitates seqio seed generation, which is how seeds are commonly
# passed to t5 preprocessors.
base_seed = 42
random_ds_seeds = np.arange(base_seed, base_seed + 2).reshape(-1, 2)
random_ds_seeds = tuple(tuple(s) for s in random_ds_seeds)
seed_ds = tf.nest.map_structure(
tf.data.experimental.RandomDataset, random_ds_seeds
)
range_ds = tf.data.Dataset.from_tensor_slices(range(1))
zip_ds = tf.data.Dataset.zip(range_ds, seed_ds)
for d in zip_ds:
seed_tuple = d[1][0]
cast_rng = prep.maybe_cast_seed(seed_tuple)
tf_seed = tf.random.experimental.stateless_split(cast_rng)
expected_seed = tf.constant(
[
[-9175096281216858533, 7122269515015050198],
[-2581824382145232558, -4001119725515448663],
],
dtype=tf.int64,
)
tf.assert_equal(tf_seed, expected_seed)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion t5/data/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/eval_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/eval_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/qa_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/qa_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/scoring_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/evaluation/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/gin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/gin/learning_rate_schedules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/gin/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/gin/objectives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/gin/sequence_lengths/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/hf_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/mesh_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/mesh_transformer_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/mesh_transformer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/mtf_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/t5_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/dump_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/dump_task_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/parse_tb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/prepare_glue_submission.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/scripts/transform_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/seqio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion t5/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The T5 Authors.
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit c7cdaa5

Please sign in to comment.