Skip to content

Commit

Permalink
Temporarily skips T5 input tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
markblee committed Jul 17, 2023
1 parent bdf476f commit 3c00119
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions axlearn/common/input_t5_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import partial
from typing import Dict, List, Tuple

import pytest
import seqio
import tensorflow as tf
from absl.testing import absltest, parameterized
Expand Down Expand Up @@ -460,6 +461,9 @@ def _count_spans(x):
expected=tf.constant([32099, 12, 13, 14, 15, 32098, 19]),
),
)
@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_noise_span_to_unique_sentinel(
self, input_ids: tf.Tensor, noise_mask: tf.Tensor, expected: tf.Tensor
):
Expand Down Expand Up @@ -512,6 +516,9 @@ def test_noise_span_to_unique_sentinel(
expected=tf.constant([10, 11, 32099, 16, 17, 18, 32098]),
),
)
@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_nonnoise_span_to_unique_sentinel(
self, input_ids: tf.Tensor, noise_mask: tf.Tensor, expected: tf.Tensor
):
Expand Down Expand Up @@ -591,6 +598,9 @@ def test_nonnoise_span_to_unique_sentinel(
# ],
),
)
@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_apply_t5_mask(
self,
source_ids: tf.Tensor,
Expand Down Expand Up @@ -703,6 +713,9 @@ def wrapped(*args, seeds, **kwargs):
# ],
),
)
@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_make_t5_autoregressive_inputs(self, examples: List[Dict[str, tf.Tensor]], **kwargs):
tf.random.set_seed(1234)
source = fake_source(
Expand Down Expand Up @@ -734,6 +747,9 @@ def assert_strictly_increasing(seq):
# Due to how we construct our inputs, the interleaved IDs should be strictly increasing.
assert_strictly_increasing(_interleave(vocab, source_ids, target_labels))

@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_make_t5_autoregressive_inputs_validation(self):
with self.assertRaisesRegex(ValueError, "exceeds max target"):
make_t5_autoregressive_inputs(
Expand Down

0 comments on commit 3c00119

Please sign in to comment.