From 37a585f944bc0fb5228e373639dd67b081e72ee5 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 3 Nov 2022 12:40:53 +0000
Subject: [PATCH] Add tests

---
 .../run_image_classification.py               | 58 +++++++++++++++----
 .../tensorflow/test_tensorflow_examples.py    | 26 +++++++++
 2 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/examples/tensorflow/image-classification/run_image_classification.py b/examples/tensorflow/image-classification/run_image_classification.py
index 336e5dcdbfedec..5e3a8a3ce14222 100644
--- a/examples/tensorflow/image-classification/run_image_classification.py
+++ b/examples/tensorflow/image-classification/run_image_classification.py
@@ -19,6 +19,7 @@
 https://huggingface.co/models?filter=image-classification
 """
 
+import json
 import logging
 import os
 import sys
@@ -89,6 +90,13 @@ class DataTrainingArguments:
     train_val_split: Optional[float] = field(
         default=0.15, metadata={"help": "Percent to split off of train for validation."}
     )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -324,8 +332,11 @@ def main():
 
     # Define our data preprocessing function. It takes an image file path as input and returns
     # Write a note describing the resizing behaviour.
-    image_size = feature_extractor.size
-    image_size = (image_size, image_size) if isinstance(image_size, int) else image_size
+    if "shortest_edge" in feature_extractor.size:
+        # We set the target size to (shortest_edge, shortest_edge) here so that all images are batchable.
+ image_size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"]) + else: + image_size = (feature_extractor.size["height"], feature_extractor.size["width"]) def _train_transforms(image): img_size = image_size @@ -366,7 +377,12 @@ def val_transforms(example_batch): train_dataset = dataset["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) - train_dataset.set_transform(train_transforms) + train_dataset = train_dataset.map( + train_transforms, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) eval_dataset = None if training_args.do_eval: @@ -376,7 +392,12 @@ def val_transforms(example_batch): if data_args.max_eval_samples is not None: eval_dataset = eval_dataset.shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) # Set the validation transforms - eval_dataset.set_transform(val_transforms) + eval_dataset = eval_dataset.map( + val_transforms, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) predict_dataset = None if training_args.do_predict: @@ -388,7 +409,12 @@ def val_transforms(example_batch): range(data_args.max_predict_samples) ) # Set the test transforms - predict_dataset.set_transform(val_transforms) + predict_dataset = predict_dataset.map( + val_transforms, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) collate_fn = DefaultDataCollator(return_tensors="tf") @@ -400,7 +426,9 @@ def val_transforms(example_batch): def compute_metrics(p): """Computes accuracy on a batch of predictions""" logits, label_ids = p - return metric.compute(predictions=np.argmax(logits, axis=1), references=label_ids) + predictions = np.argmax(logits, axis=-1) + metrics = metric.compute(predictions=predictions, references=label_ids) + return metrics with training_args.strategy.scope(): if checkpoint is None: @@ -513,12 +541,22 @@ def compute_metrics(p): ) if training_args.do_eval: - model.evaluate(eval_dataset, steps=len(eval_dataset)) + eval_predictions = model.predict(eval_dataset, steps=len(eval_dataset)) + eval_metrics = compute_metrics((eval_predictions.logits, dataset["validation"]["labels"])) + logging.info("Eval metrics:") + for metric, value in eval_metrics.items(): + logging.info(f"{metric}: {value:.3f}") + + if training_args.output_dir is not None: + with open(os.path.join(training_args.output_dir, "all_results.json"), "w") as f: + f.write(json.dumps(eval_metrics)) if training_args.do_predict: - predictions = model.predict(predict_dataset, steps=len(predict_dataset)) - test_metrics = compute_metrics(predictions, labels=predict_dataset.map(lambda x, y: y)) - logging.info(f"Test metrics: {test_metrics}") + test_predictions = model.predict(predict_dataset, steps=len(predict_dataset)) + test_metrics = compute_metrics((test_predictions.logits, dataset["validation"]["labels"])) + logging.info("Test metrics:") + for metric, value in test_metrics.items(): + logging.info(f"{metric}: {value:.3f}") if training_args.output_dir is not None and not training_args.push_to_hub: # If we're not pushing to hub, at least save a local copy when we're done diff --git a/examples/tensorflow/test_tensorflow_examples.py b/examples/tensorflow/test_tensorflow_examples.py index f4b383eabe5303..1b3b52efcccc85 100644 --- 
+++ b/examples/tensorflow/test_tensorflow_examples.py
@@ -38,6 +38,7 @@
         "question-answering",
         "summarization",
         "translation",
+        "image-classification",
     ]
 ]
 sys.path.extend(SRC_DIRS)
@@ -45,6 +46,7 @@
 
 if SRC_DIRS is not None:
     import run_clm
+    import run_image_classification
     import run_mlm
     import run_ner
     import run_qa as run_squad
@@ -294,3 +296,27 @@ def test_run_translation(self):
             run_translation.main()
             result = get_results(tmp_dir)
             self.assertGreaterEqual(result["bleu"], 30)
+
+    @slow
+    def test_run_image_classification(self):
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_image_classification.py
+            --model_name_or_path google/vit-base-patch16-224-in21k
+            --dataset_name beans
+            --output_dir {tmp_dir}
+            --do_train
+            --do_eval
+            --overwrite_output_dir
+            --learning_rate 2e-5
+            --per_device_train_batch_size 8
+            --per_device_eval_batch_size 8
+            --max_steps 10
+            --seed 1337
+            --ignore_mismatched_sizes True
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_image_classification.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["accuracy"], 0.7)
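
Reviewer note (not part of the patch): the new image_size branch covers the two shapes
that feature_extractor.size can take on current image processors, a {"shortest_edge": n}
dict or a {"height": h, "width": w} dict. Below is a minimal standalone sketch of the
same resolution logic; the function name resolve_image_size is ours, not the script's.

    def resolve_image_size(size: dict) -> tuple:
        # Square target for "shortest_edge" so every image in a batch has the same shape.
        if "shortest_edge" in size:
            return (size["shortest_edge"], size["shortest_edge"])
        return (size["height"], size["width"])

    assert resolve_image_size({"shortest_edge": 224}) == (224, 224)
    assert resolve_image_size({"height": 384, "width": 512}) == (384, 512)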
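
The substantive change in run_image_classification.py swaps Dataset.set_transform (lazy:
the transform runs on every access and nothing is cached) for Dataset.map (eager: the
transform runs once and the result is materialized and cached on disk), which is what
makes the new --preprocessing_num_workers and --overwrite_cache arguments meaningful.
A toy sketch of the difference using the datasets library; the dataset and transform
here are illustrative only.

    from datasets import Dataset

    def double(batch):
        # Batched transform: each column arrives as a list of values.
        return {"x": [v * 2 for v in batch["x"]]}

    lazy = Dataset.from_dict({"x": [1, 2, 3]})
    lazy.set_transform(double)  # runs at __getitem__ time, on every access

    eager = Dataset.from_dict({"x": [1, 2, 3]}).map(
        double,
        batched=True,
        num_proc=None,              # parallel workers, cf. --preprocessing_num_workers
        load_from_cache_file=True,  # reuse the on-disk cache, cf. not --overwrite_cache
    )
    print(lazy[0], eager[0])  # the eager copy is read back from the materialized dataset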
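
The compute_metrics rewrite also moves the argmax from axis=1 to axis=-1; the two agree
for 2-D logits, but only axis=-1 always targets the class dimension. A self-contained
sketch of the accuracy computation, with a hand-rolled accuracy standing in for the
evaluate metric object the script actually uses:

    import numpy as np

    def accuracy_from_logits(logits: np.ndarray, label_ids: np.ndarray) -> dict:
        predictions = np.argmax(logits, axis=-1)  # predicted class index per example
        return {"accuracy": float((predictions == label_ids).mean())}

    logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    labels = np.array([1, 0, 0])
    print(accuracy_from_logits(logits, labels))  # two of three correct -> ~0.667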