Skip to content

Commit

Permalink
[tests] fix typos in inputs (huggingface#6818)
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 authored Aug 30, 2020
1 parent 22933e6 commit 563485b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
14 changes: 7 additions & 7 deletions tests/test_tokenization_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,29 @@ def default_tokenizer_fast(self):

@require_torch
def test_prepare_seq2seq_batch(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]

for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)

self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
self.assertEqual((2, 9), batch.input_ids.shape)
self.assertEqual((2, 9), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset

# Test Prepare Seq
@require_torch
def test_seq2seq_batch_empty_target_text(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no labels
Expand All @@ -102,7 +102,7 @@ def test_seq2seq_batch_empty_target_text(self):

@require_torch
def test_seq2seq_batch_max_target_length(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_seq2seq_batch_not_longer_than_maxlen(self):
@require_torch
def test_special_tokens(self):

src_text = ["A long paragraph for summrization."]
src_text = ["A long paragraph for summarization."]
tgt_text = [
"Summary of the text.",
]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_tokenization_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def test_eos_treatment(self):

def test_prepare_seq2seq_batch(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer.prepare_seq2seq_batch(
src_text,
tgt_texts=tgt_text,
Expand All @@ -135,15 +135,15 @@ def test_prepare_seq2seq_batch(self):
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)

self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
self.assertEqual((2, 9), batch.input_ids.shape)
self.assertEqual((2, 9), batch.attention_mask.shape)

# Test that special tokens are reset
self.assertEqual(tokenizer.prefix_tokens, [])

def test_empty_target_text(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
Expand All @@ -153,7 +153,7 @@ def test_empty_target_text(self):

def test_max_target_length(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
Expand All @@ -180,9 +180,9 @@ def test_outputs_not_longer_than_maxlen(self):

def test_eos_in_input(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization. </s>"]
src_text = ["A long paragraph for summarization. </s>"]
tgt_text = ["Summary of the text. </s>"]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]

batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
Expand Down

0 comments on commit 563485b

Please sign in to comment.