Skip to content

Commit

Permalink
add test case for data collator
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhui Dih Lee authored and Rhui Dih Lee committed Jul 16, 2024
1 parent f5fa856 commit 808fd63
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/trainer/test_data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
DataCollatorWithFlattening,
default_data_collator,
is_tf_available,
is_torch_available,
Expand Down Expand Up @@ -1531,6 +1532,19 @@ def test_data_collator_with_padding(self):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (2, 8))

def test_data_collator_with_flattening(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": [10, 11, 12]}, {"input_ids": [20, 21, 22, 23, 24, 25]}, {"input_ids": [30, 31, 32, 33, 34, 35, 36]}]

data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36])
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
Expand Down

0 comments on commit 808fd63

Please sign in to comment.