Skip to content

Commit 8f20e61

Browse files
authored
Update feature selection in to_tf_dataset (#21935)
* Update feature selection
* Check compatibility with datasets version
* Checkout from datasets main
1 parent 345a137 commit 8f20e61

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

docs/source/en/tasks/image_classification.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,12 @@ Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Data
385385
```py
386386
>>> # converting our train dataset to tf.data.Dataset
387387
>>> tf_train_dataset = food["train"].to_tf_dataset(
388-
... columns=["pixel_values"], label_cols=["label"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
388+
... columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator
389389
... )
390390

391391
>>> # converting our test dataset to tf.data.Dataset
392392
>>> tf_eval_dataset = food["test"].to_tf_dataset(
393-
... columns=["pixel_values"], label_cols=["label"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
393+
... columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator
394394
... )
395395
```
396396

docs/source/es/training.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,15 @@ A continuación, convierte los datasets tokenizados en datasets de TensorFlow co
173173
```py
174174
>>> tf_train_dataset = small_train_dataset.to_tf_dataset(
175175
... columns=["attention_mask", "input_ids", "token_type_ids"],
176-
... label_cols=["labels"],
176+
... label_cols="labels",
177177
... shuffle=True,
178178
... collate_fn=data_collator,
179179
... batch_size=8,
180180
... )
181181

182182
>>> tf_validation_dataset = small_eval_dataset.to_tf_dataset(
183183
... columns=["attention_mask", "input_ids", "token_type_ids"],
184-
... label_cols=["labels"],
184+
... label_cols="labels",
185185
... shuffle=False,
186186
... collate_fn=data_collator,
187187
... batch_size=8,

docs/source/pt/training.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,15 @@ Especifique suas entradas em `columns` e seu rótulo em `label_cols`:
205205
```py
206206
>>> tf_train_dataset = small_train_dataset.to_tf_dataset(
207207
... columns=["attention_mask", "input_ids", "token_type_ids"],
208-
... label_cols=["labels"],
208+
... label_cols="labels",
209209
... shuffle=True,
210210
... collate_fn=data_collator,
211211
... batch_size=8,
212212
... )
213213

214214
>>> tf_validation_dataset = small_eval_dataset.to_tf_dataset(
215215
... columns=["attention_mask", "input_ids", "token_type_ids"],
216-
... label_cols=["labels"],
216+
... label_cols="labels",
217217
... shuffle=False,
218218
... collate_fn=data_collator,
219219
... batch_size=8,

src/transformers/modeling_tf_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,12 @@ def prepare_tf_dataset(
14131413
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
14141414
label_cols = [col for col in output_columns if col in model_labels]
14151415

1416+
# Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols`
1417+
# were a single-element list, the returned element spec would be a single element. Now, passing [feature]
1418+
# will return a dict structure {"feature": feature}, and passing a single string will return a single element.
1419+
feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols
1420+
label_cols = label_cols[0] if len(label_cols) == 1 else label_cols
1421+
14161422
if drop_remainder is None:
14171423
drop_remainder = shuffle
14181424
tf_dataset = dataset.to_tf_dataset(

0 commit comments

Comments (0)