Skip to content

Commit

Permalink
Fixed datasets error in v2.3.x (intel#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng authored Jun 24, 2022
1 parent eaf9cfe commit 45e6a1a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ def compute_metrics(preds, label_ids):
collate_fn=data_collator,
drop_remainder=drop_remainder,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in dataset.column_names else None,
# datasets v2.3.x need "labels", not "label"
label_cols=["labels", "label"] if "label" in dataset.column_names else None,
)
tf_data[key] = data
# endregion
Expand Down Expand Up @@ -593,4 +594,4 @@ def compute_metrics(preds, label_ids):


if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,8 @@ def compute_metrics(preds, label_ids):
collate_fn=data_collator,
drop_remainder=drop_remainder,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in dataset.column_names else None,
# datasets v2.3.x need "labels", not "label"
label_cols=["labels", "label"] if "label" in dataset.column_names else None,
)
tf_data[key] = data
# endregion
Expand Down Expand Up @@ -619,4 +620,4 @@ def compute_metrics(preds, label_ids):


if __name__ == "__main__":
main()
main()
3 changes: 2 additions & 1 deletion tests/test_tf_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def preprocess_function(examples):
collate_fn=data_collator,
drop_remainder=False,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in dataset.column_names else None,
# datasets v2.3.x need "labels", not "label"
label_cols=["label", "labels"] if "label" in dataset.column_names else None,
)
parser = HfArgumentParser(TFTrainingArguments)
self.args = parser.parse_args_into_dataclasses(args=["--output_dir", "./quantized_model",
Expand Down
7 changes: 4 additions & 3 deletions tests/test_tf_quantization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import os
import shutil
import tensorflow as tf
import tensorflow as tf
import unittest
from datasets import load_dataset, load_metric
from nlp_toolkit import (
Expand Down Expand Up @@ -43,14 +43,15 @@ def preprocess_function(examples):
data_collator = DefaultDataCollator(return_tensors="tf")
dataset = raw_datasets.select(range(10))
self.dummy_dataset = dataset.to_tf_dataset(
columns=[col for col in dataset.column_names if col not in
columns=[col for col in dataset.column_names if col not in
set(non_label_column_names + ["label"])],
shuffle=False,
batch_size=2,
collate_fn=data_collator,
drop_remainder=False,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in dataset.column_names else None,
# datasets v2.3.x need "labels", not "label"
label_cols=["label", "labels"] if "label" in dataset.column_names else None,
)


Expand Down

0 comments on commit 45e6a1a

Please sign in to comment.