diff --git a/t5/data/preprocessors.py b/t5/data/preprocessors.py
index c909b38b..8a3ebbcc 100644
--- a/t5/data/preprocessors.py
+++ b/t5/data/preprocessors.py
@@ -2744,17 +2744,19 @@ def __call__(self, tokens: tf.Tensor, noise_mask: tf.Tensor, vocabulary,
     """Computes the target tokens. Seeds should have shape [2, 2]."""
 
 
-def single_example_denoise(features: FeatureType,
-                           seed: tf.Tensor,
-                           *,
-                           output_features: Mapping[str, Any],
-                           noise_density: float,
-                           noise_mask_fn: DenoiseNoiseMaskFn,
-                           inputs_fn: DenoiseInputsFn,
-                           targets_fn: Optional[DenoiseTargetsFn] = None,
-                           passthrough_feature_keys: Optional[
-                               Sequence[str]] = None,
-                           input_feature_key: str = 'inputs') -> FeatureType:
+def single_example_denoise(
+    features: FeatureType,
+    seed: tf.Tensor,
+    *,
+    output_features: Mapping[str, Any],
+    noise_density: float,
+    noise_mask_fn: DenoiseNoiseMaskFn,
+    inputs_fn: DenoiseInputsFn,
+    targets_fn: Optional[DenoiseTargetsFn] = None,
+    passthrough_feature_keys: Optional[Sequence[str]] = None,
+    input_feature_key: str = 'inputs',
+    batch_size: int | None = None,
+) -> FeatureType:
   """Preprocessing function for self-supervised denoising tasks.
 
   This function takes a dataset containing "targets" sequences,
@@ -2796,6 +2798,8 @@ def single_example_denoise(features: FeatureType,
     targets_fn: a function from (tokens, noise_mask, vocabulary) -> tokens
     passthrough_feature_keys: names of additional features to include in output
     input_feature_key: name of feature to use as inputs
+    batch_size: an optional int indicating the batch size if `features` is a
+      dict of batched features.
 
   Returns:
     The preprocessed features.
@@ -2814,7 +2818,14 @@ def single_example_denoise(features: FeatureType,
     raise ValueError(
         'denoise creates inputs based on tokenized targets but was applied '
         'to a task that uses different vocabularies for inputs and targets.')
-  noise_mask = noise_mask_fn(tf.size(tokens), noise_density, seeds=seeds[:2])
+  if batch_size:
+    # noise_mask_fn, inputs_fn and targets_fn must support a batch_size arg.
+    noise_mask_fn = functools.partial(noise_mask_fn, batch_size=batch_size)
+    inputs_fn = functools.partial(inputs_fn, batch_size=batch_size)
+    if targets_fn:
+      targets_fn = functools.partial(targets_fn, batch_size=batch_size)
+  length = tf.size(tokens) // (batch_size or 1)
+  noise_mask = noise_mask_fn(length, noise_density, seeds=seeds[:2])
   inputs = inputs_fn(tokens, noise_mask, vocabulary, seeds=seeds[2:4])
   if targets_fn:
     targets = targets_fn(tokens, noise_mask, vocabulary, seeds=seeds[4:6])
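
To see the new batched control flow end to end, here is a minimal sketch (illustration only, not part of the patch). `toy_noise_mask_fn` is a hypothetical stand-in for a mask function such as random_spans_noise_mask; the only assumption is that it accepts the new batch_size keyword, which is exactly what the functools.partial wrapping above relies on.

import functools

import tensorflow as tf


def toy_noise_mask_fn(length, noise_density, seeds, batch_size=None):
  # Hypothetical stand-in: mark the first noise_density * length positions.
  del seeds  # A real mask fn would consume the stateless seeds.
  num_noise = tf.cast(noise_density * tf.cast(length, tf.float32), tf.int32)
  row = tf.range(length) < num_noise
  if batch_size:
    return tf.tile(row[None, :], [batch_size, 1])  # [batch_size, length]
  return row  # [length]


batch_size = 4
tokens = tf.zeros([batch_size, 8], tf.int32)  # a batch of four length-8 rows
# tf.size counts every element, so dividing by the batch size recovers the
# per-row length, mirroring the patch's `tf.size(tokens) // (batch_size or 1)`.
length = tf.size(tokens) // (batch_size or 1)  # == 8
noise_mask_fn = functools.partial(toy_noise_mask_fn, batch_size=batch_size)
noise_mask = noise_mask_fn(length, 0.5, seeds=None)
assert noise_mask.shape == (batch_size, 8)

Note that `(batch_size or 1)` keeps the unbatched path behaviorally identical: with batch_size=None the division is by 1 and none of the functions are wrapped.
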
""" if noise_density == 0.0: @@ -2966,6 +2983,11 @@ def to_float(x): # avoid degeneracy by ensuring positive number of noise spans num_noise_spans = tf.maximum(num_noise_spans, 1) num_nonnoise_tokens = length - num_noise_tokens + if batch_size: + # Create seeds to generate masks for each row. + seeds = tf.unstack( + tf.random.experimental.stateless_split(seeds[0], batch_size * 2) + ) # pick the lengths of the noise spans and the non-noise spans def _random_segmentation(num_items, num_segments, seed): """Partition a sequence of items randomly into non-empty segments. @@ -2986,29 +3008,35 @@ def _random_segmentation(num_items, num_segments, seed): segment_id = tf.cumsum(first_in_segment) segment_length = tf.math.segment_sum(tf.ones_like(segment_id), segment_id) return segment_length - noise_span_lengths = _random_segmentation( - num_noise_tokens, num_noise_spans, seeds[0]) - nonnoise_span_lengths = _random_segmentation( - num_nonnoise_tokens, num_noise_spans, seeds[1]) - interleaved_span_lengths = tf.reshape( - tf.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), - [num_noise_spans * 2]) - span_starts = tf.cumsum(interleaved_span_lengths)[:-1] - span_start_indicator = tf.math.unsorted_segment_sum( - tf.ones_like(span_starts), span_starts, length) - span_num = tf.cumsum(span_start_indicator) - is_noise = tf.equal(span_num % 2, 1) - - mask = is_noise[:orig_length] - - if random_roll: - roll_seed = (seeds[0][0]+seeds[1][1], seeds[0][1]-seeds[1][0]) # new seed. - # Roll the mask by a random offset e.g. for offset=2: [1,2,3,4] => [3,4,1,2] - offset = tf.random.stateless_uniform( - [1], seed=roll_seed, dtype=tf.int32, minval=0, maxval=length)[0] - mask = tf.roll(mask, shift=offset, axis=0) - - return mask + masks = [] + for i in range(batch_size or 1): + noise_span_lengths = _random_segmentation( + num_noise_tokens, num_noise_spans, seeds[2 * i]) + nonnoise_span_lengths = _random_segmentation( + num_nonnoise_tokens, num_noise_spans, seeds[2 * i + 1]) + interleaved_span_lengths = tf.reshape( + tf.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), + [num_noise_spans * 2]) + span_starts = tf.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = tf.math.unsorted_segment_sum( + tf.ones_like(span_starts), span_starts, length) + span_num = tf.cumsum(span_start_indicator) + is_noise = tf.equal(span_num % 2, 1) + + mask = is_noise[:orig_length] + + if random_roll: + roll_seed = (seeds[0][0]+seeds[1][1], seeds[0][1]-seeds[1][0]) # new seed + # Roll the mask by a random offset e.g. for offset=2: [1,2,3,4] => + # [3,4,1,2] + offset = tf.random.stateless_uniform( + [1], seed=roll_seed, dtype=tf.int32, minval=0, maxval=length)[0] + mask = tf.roll(mask, shift=offset, axis=0) + masks.append(mask) + + if not batch_size: + return masks[0] + return tf.stack(masks, axis=0) @gin.configurable() @@ -3110,7 +3138,9 @@ def nonnoise_span_to_sentinel(tokens, noise_mask, vocabulary, seeds): @gin.configurable() -def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds): +def noise_span_to_unique_sentinel( + tokens, noise_mask, vocabulary, seeds, batch_size=None +): """Replace each run of consecutive noise tokens with a different sentinel. The idea here is to be able to align the dropped spans in the inputs @@ -3132,28 +3162,67 @@ def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds): noise_mask: a boolean Tensor with the same shape as tokens vocabulary: a vocabulary.Vocabulary seeds: an unused int32 Tensor + batch_size: an optional int32; if tokens are batched. 
@@ -3110,7 +3138,9 @@ def nonnoise_span_to_sentinel(tokens, noise_mask, vocabulary, seeds):
 
 
 @gin.configurable()
-def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
+def noise_span_to_unique_sentinel(
+    tokens, noise_mask, vocabulary, seeds, batch_size=None
+):
   """Replace each run of consecutive noise tokens with a different sentinel.
 
   The idea here is to be able to align the dropped spans in the inputs
@@ -3132,28 +3162,67 @@ def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
     tokens: a 1d integer Tensor
     noise_mask: a boolean Tensor with the same shape as tokens
     vocabulary: a vocabulary.Vocabulary
     seeds: an unused int32 Tensor
+    batch_size: an optional int32; the batch size, if `tokens` is batched.
+
   Returns:
     a Tensor with the same shape and dtype as tokens
   """
   del seeds
-
-  prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]])
+  if batch_size:
+    def shift_batched_right_by_one(arr, fill_value):
+      if not (arr.dtype.is_integer or arr.dtype.is_floating):
+        raise ValueError(f'Only numeric types are supported. Got: {arr.dtype}')
+      # tf.roll wraps around the axis, so the last element lands at index 0.
+      rolled = tf.roll(arr, shift=1, axis=1)
+
+      # Overwrite the first position with fill_value by multiplying with
+      # [0, 1, 1, ..., 1].
+      depth = tf.shape(arr)[1]
+      mask = tf.one_hot(
+          0, depth=depth, on_value=0, off_value=1, dtype=arr.dtype
+      )
+      # Tile the mask to match the batch size.
+      shape = tf.shape(arr)
+      mask = tf.tile(tf.expand_dims(mask, axis=0), [batch_size, 1])
+      # Broadcast the mask to match the rank of `rolled`.
+      for _ in range(len(shape) - 2):
+        mask = tf.expand_dims(mask, axis=-1)
+      return rolled * mask + (1 - mask) * fill_value
+
+    int_mask = tf.cast(noise_mask, tf.int32)
+    shifted_mask = shift_batched_right_by_one(int_mask, fill_value=0)
+    prev_token_is_noise = tf.cast(shifted_mask, tf.bool)
+  else:
+    prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]])
   first_noise_tokens = tf.logical_and(
       noise_mask, tf.logical_not(prev_token_is_noise))
   subsequent_noise_tokens = tf.logical_and(noise_mask, prev_token_is_noise)
-  sentinel = sentinel_id(vocabulary) + 1 - tf.cumsum(
-      tf.cast(first_noise_tokens, tokens.dtype))
+  sentinel = (
+      sentinel_id(vocabulary)
+      + 1
+      - tf.cumsum(tf.cast(first_noise_tokens, tokens.dtype), axis=-1)
+  )
   tokens = tf.where(first_noise_tokens, sentinel, tokens)
-  return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens))
+  if not batch_size:
+    return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens))
+  return tf.stack([
+      tf.boolean_mask(t, tf.logical_not(n))
+      for t, n in zip(tf.unstack(tokens), tf.unstack(subsequent_noise_tokens))
+  ])
 
 
 @gin.configurable()
-def nonnoise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
+def nonnoise_span_to_unique_sentinel(
+    tokens, noise_mask, vocabulary, seeds, batch_size=None
+):
   return noise_span_to_unique_sentinel(
-      tokens, tf.logical_not(noise_mask), vocabulary, seeds)
+      tokens,
+      tf.logical_not(noise_mask),
+      vocabulary,
+      seeds,
+      batch_size=batch_size,
+  )
 
 
 @gin.configurable()
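
Finally, a quick standalone check (illustration only; the mask values are arbitrary) that the roll-and-mask trick in shift_batched_right_by_one agrees with the unbatched tf.pad formulation for fill_value=0, which is how the patch invokes it:

import tensorflow as tf

noise_mask = tf.constant([[False, True, True, False],
                          [True, True, False, True]])
int_mask = tf.cast(noise_mask, tf.int32)

# Roll right by one; the last element wraps around to index 0 ...
rolled = tf.roll(int_mask, shift=1, axis=1)
# ... then multiplying by [0, 1, 1, ..., 1] overwrites index 0 with the
# fill_value of 0.
keep = tf.one_hot(0, depth=tf.shape(int_mask)[1],
                  on_value=0, off_value=1, dtype=tf.int32)
prev_token_is_noise = tf.cast(rolled * keep, tf.bool)

# Row-wise equivalence with the unbatched tf.pad(noise_mask[:-1], [[1, 0]]).
for row, shifted in zip(
    tf.unstack(noise_mask), tf.unstack(prev_token_is_noise)):
  tf.debugging.assert_equal(tf.pad(row[:-1], [[1, 0]]), shifted)

The per-row tf.boolean_mask followed by tf.stack at the end of noise_span_to_unique_sentinel only works because every row keeps the same number of positions; that holds here since num_noise_tokens and num_noise_spans are derived from the shared length and noise_density, so each row drops exactly num_noise_tokens - num_noise_spans subsequent-noise tokens.
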