Commit

Add tests
amyeroberts committed Nov 3, 2022
1 parent a225625 commit 270bfb0
Showing 2 changed files with 74 additions and 10 deletions.
58 changes: 48 additions & 10 deletions examples/tensorflow/image-classification/run_image_classification.py
@@ -19,6 +19,7 @@
 https://huggingface.co/models?filter=image-classification
 """
 
+import json
 import logging
 import os
 import sys
@@ -89,6 +90,13 @@ class DataTrainingArguments:
     train_val_split: Optional[float] = field(
         default=0.15, metadata={"help": "Percent to split off of train for validation."}
     )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
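For reference, fields declared this way surface automatically as command-line flags through HfArgumentParser. A minimal sketch under that assumption, with a stand-in dataclass containing only the two new options:

# Minimal sketch: HfArgumentParser turns dataclass fields into CLI flags.
# The dataclass below is a stand-in containing only the two new options.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None, metadata={"help": "The number of processes to use for the preprocessing."}
    )


parser = HfArgumentParser(DataTrainingArguments)
# A bool defaulting to False becomes a presence flag; Optional[int] takes a value.
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--overwrite_cache", "--preprocessing_num_workers", "4"]
)
print(data_args.overwrite_cache, data_args.preprocessing_num_workers)  # True 4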
@@ -324,8 +332,11 @@ def main():
 
     # Define our data preprocessing function. It takes an image file path as input and returns
     # Write a note describing the resizing behaviour.
-    image_size = feature_extractor.size
-    image_size = (image_size, image_size) if isinstance(image_size, int) else image_size
+    if "shortest_edge" in feature_extractor.size:
+        # We instead set the target size as (shortest_edge, shortest_edge) here to ensure all images are batchable.
+        image_size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
+    else:
+        image_size = (feature_extractor.size["height"], feature_extractor.size["width"])
 
     def _train_transforms(image):
         img_size = image_size
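Note: `size` on newer feature extractors is a dict rather than a bare int, either {"shortest_edge": n} for shortest-edge resizing or {"height": ..., "width": ...} for fixed-size resizing, which is what the branch above handles. A standalone sketch of the same logic with hypothetical values:

# Standalone sketch of the size-resolution branch above (hypothetical values).
def resolve_image_size(size: dict) -> tuple:
    if "shortest_edge" in size:
        # (shortest_edge, shortest_edge) keeps every image in a batch the same shape.
        return (size["shortest_edge"], size["shortest_edge"])
    return (size["height"], size["width"])


assert resolve_image_size({"shortest_edge": 224}) == (224, 224)
assert resolve_image_size({"height": 384, "width": 384}) == (384, 384)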
@@ -366,7 +377,12 @@ def val_transforms(example_batch):
         train_dataset = dataset["train"]
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
-        train_dataset.set_transform(train_transforms)
+        train_dataset = train_dataset.map(
+            train_transforms,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
 
     eval_dataset = None
     if training_args.do_eval:
@@ -376,7 +392,12 @@ def val_transforms(example_batch):
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
         # Set the validation transforms
-        eval_dataset.set_transform(val_transforms)
+        eval_dataset = eval_dataset.map(
+            val_transforms,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
 
     predict_dataset = None
     if training_args.do_predict:
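Note: these hunks swap lazy, on-access preprocessing (`set_transform`) for an eager, cached pass over the whole dataset (`map`), which is why the new preprocessing_num_workers and overwrite_cache options matter here. A toy contrast using the datasets library, with a hypothetical add_one transform:

# Toy contrast between lazy set_transform and eager, cached map.
from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})


def add_one(batch):
    return {"x": [v + 1 for v in batch["x"]]}


eager = ds.map(add_one, batched=True, load_from_cache_file=False)  # runs once, up front

lazy = Dataset.from_dict({"x": [1, 2, 3]})
lazy.set_transform(add_one)  # runs on every __getitem__ access instead

print(eager[0]["x"], lazy[0]["x"])  # 2 2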
@@ -388,7 +409,12 @@ def val_transforms(example_batch):
                 range(data_args.max_predict_samples)
             )
         # Set the test transforms
-        predict_dataset.set_transform(val_transforms)
+        predict_dataset = predict_dataset.map(
+            val_transforms,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
 
     collate_fn = DefaultDataCollator(return_tensors="tf")
 
@@ -400,7 +426,9 @@ def val_transforms(example_batch):
     def compute_metrics(p):
         """Computes accuracy on a batch of predictions"""
         logits, label_ids = p
-        return metric.compute(predictions=np.argmax(logits, axis=1), references=label_ids)
+        predictions = np.argmax(logits, axis=-1)
+        metrics = metric.compute(predictions=predictions, references=label_ids)
+        return metrics
 
     with training_args.strategy.scope():
         if checkpoint is None:
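Note: taking argmax over the last axis converts (batch, num_labels) logits into class indices, and axis=-1 stays correct even if extra leading dimensions appear. A quick standalone check, assuming `metric` was created with evaluate.load("accuracy") (its construction is outside this hunk):

# Quick check of the argmax-then-accuracy pattern with hypothetical logits.
import numpy as np
import evaluate  # assumption: the script's `metric` comes from evaluate.load("accuracy")

metric = evaluate.load("accuracy")
logits = np.array([[0.1, 2.0], [3.0, 0.2]])  # shape (batch, num_labels)
label_ids = np.array([1, 0])
predictions = np.argmax(logits, axis=-1)  # -> [1, 0]
print(metric.compute(predictions=predictions, references=label_ids))  # {'accuracy': 1.0}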
@@ -513,12 +541,22 @@ def compute_metrics(p):
         )
 
         if training_args.do_eval:
-            model.evaluate(eval_dataset, steps=len(eval_dataset))
+            eval_predictions = model.predict(eval_dataset, steps=len(eval_dataset))
+            eval_metrics = compute_metrics((eval_predictions.logits, dataset["validation"]["labels"]))
+            logging.info("Eval metrics:")
+            for metric, value in eval_metrics.items():
+                logging.info(f"{metric}: {value:.3f}")
+
+            if training_args.output_dir is not None:
+                with open(os.path.join(training_args.output_dir, "all_results.json"), "w") as f:
+                    f.write(json.dumps(eval_metrics))
 
         if training_args.do_predict:
-            predictions = model.predict(predict_dataset, steps=len(predict_dataset))
-            test_metrics = compute_metrics(predictions, labels=predict_dataset.map(lambda x, y: y))
-            logging.info(f"Test metrics: {test_metrics}")
+            test_predictions = model.predict(predict_dataset, steps=len(predict_dataset))
+            test_metrics = compute_metrics((test_predictions.logits, dataset["validation"]["labels"]))
+            logging.info("Test metrics:")
+            for metric, value in test_metrics.items():
+                logging.info(f"{metric}: {value:.3f}")
 
         if training_args.output_dir is not None and not training_args.push_to_hub:
             # If we're not pushing to hub, at least save a local copy when we're done
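Note: the eval metrics are serialized to all_results.json so that an external harness, such as the test added below, can read them back. A minimal sketch of that round trip with hypothetical values:

# Minimal sketch of the all_results.json round trip (hypothetical values).
import json
import os

output_dir = "/tmp/example-output"  # stand-in for training_args.output_dir
os.makedirs(output_dir, exist_ok=True)

eval_metrics = {"accuracy": 0.85}  # hypothetical compute_metrics output
with open(os.path.join(output_dir, "all_results.json"), "w") as f:
    f.write(json.dumps(eval_metrics))

with open(os.path.join(output_dir, "all_results.json")) as f:
    assert json.load(f)["accuracy"] == 0.85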
26 changes: 26 additions & 0 deletions examples/tensorflow/test_tensorflow_examples.py
@@ -38,13 +38,15 @@
         "question-answering",
         "summarization",
         "translation",
+        "image-classification",
     ]
 ]
 sys.path.extend(SRC_DIRS)
 
 
 if SRC_DIRS is not None:
     import run_clm
+    import run_image_classification
     import run_mlm
     import run_ner
     import run_qa as run_squad
@@ -294,3 +296,27 @@ def test_run_translation(self):
             run_translation.main()
             result = get_results(tmp_dir)
             self.assertGreaterEqual(result["bleu"], 30)
+
+    @slow
+    def test_run_image_classification(self):
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_image_classification.py
+            --model_name_or_path google/vit-base-patch16-224-in21k
+            --dataset_name beans
+            --output_dir {tmp_dir}
+            --do_train
+            --do_eval
+            --overwrite_output_dir
+            --learning_rate 2e-5
+            --per_device_train_batch_size 8
+            --per_device_eval_batch_size 8
+            --max_steps 10
+            --seed 1337
+            --ignore_mismatched_sizes True
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_image_classification.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["accuracy"], 0.7)
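The final assertion goes through the get_results helper shared by the other example tests, which loads the all_results.json written during evaluation. Roughly, such a helper amounts to (a sketch, not the exact implementation):

# Rough sketch of a get_results-style helper: load the metrics the script wrote.
import json
import os


def get_results(output_dir):
    path = os.path.join(output_dir, "all_results.json")
    if not os.path.exists(path):
        raise ValueError(f"can't find {path}")
    with open(path) as f:
        return json.load(f)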
