[Tune] Transformer blog example #9789

Merged: 44 commits (Aug 5, 2020)

Changes from 19 commits

Commits (44)
2215479
Add PBT transformer example
Jul 21, 2020
882084b
mkdir if not exists
Jul 21, 2020
c047395
added reporter
Jul 22, 2020
bc68c8a
Added comments
Jul 23, 2020
2499032
(cont).
Jul 23, 2020
a563a28
evaluate much more often
Jul 23, 2020
5d044f4
(cont.) evaluate much more often
Jul 23, 2020
7ce40d9
Smaller batch sizes to avoid CUDA out-of-memory errors on V100 GPUs.
Jul 23, 2020
85ccad2
Start with random parameters
Jul 23, 2020
997c028
Added checkpointing to pbt example
Jul 24, 2020
ea04313
Updates
amogkam Jul 25, 2020
e67463c
more updates
amogkam Jul 29, 2020
97e64e4
fix wandb and set keep_checkpoint_num
amogkam Jul 29, 2020
c196480
add init
amogkam Jul 29, 2020
a2e8319
formatting
amogkam Jul 29, 2020
c687f11
cleanup and example link
amogkam Jul 29, 2020
1829d5e
adding smoke test
amogkam Jul 29, 2020
26d6292
add example doc
amogkam Jul 29, 2020
627f293
make smoke test shorter
amogkam Jul 29, 2020
cad0c0c
remove emoji
amogkam Jul 29, 2020
f4643b8
docs
amogkam Jul 29, 2020
bb1f520
updates
amogkam Jul 31, 2020
6762024
merging master
amogkam Jul 31, 2020
083d5b9
add test to build
amogkam Jul 31, 2020
f523553
wip
amogkam Jul 31, 2020
e3c86d6
revert back tutorials overview
amogkam Jul 31, 2020
4708a72
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Jul 31, 2020
3fe2892
formatting
amogkam Jul 31, 2020
ea5306e
fixes
amogkam Jul 31, 2020
8c0703e
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Aug 3, 2020
70d43d2
more updates
amogkam Aug 3, 2020
65ddd3d
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Aug 3, 2020
f2f13ce
adding transformers dep for tests
amogkam Aug 3, 2020
6581ece
lint
amogkam Aug 4, 2020
3c31b87
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Aug 4, 2020
5b5edb6
update test
amogkam Aug 4, 2020
ee0b7ad
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Aug 4, 2020
9b8785f
doc fix
amogkam Aug 4, 2020
f715895
fix data downloading error
amogkam Aug 4, 2020
f5f7eea
make test large
amogkam Aug 4, 2020
28b19f2
Merge branch 'master' of https://github.com/ray-project/ray into tune…
amogkam Aug 4, 2020
68978dd
adding mock test data and using smaller model
amogkam Aug 5, 2020
e8af5e9
lint
amogkam Aug 5, 2020
3fe4b04
make test even shorter
amogkam Aug 5, 2020
4 changes: 4 additions & 0 deletions ci/jenkins_tests/run_tune_tests.sh
@@ -162,6 +162,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
    python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
    --smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
    python /ray/python/ray/tune/examples/pbt_transformers/pbt_transformers.py \
    --smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
    python /ray/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py \
    --smoke-test
6 changes: 6 additions & 0 deletions doc/source/tune/examples/pbt_transformers.rst
@@ -0,0 +1,6 @@
:orphan:

pbt_transformers_example
~~~~~~~~~~~~~~~~~~~~~~~~

.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/pbt_transformers.py
5 changes: 5 additions & 0 deletions python/ray/tune/examples/README.rst
@@ -52,6 +52,11 @@ LightGBM Example

- `lightgbm_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/lightgbm_example.py>`__: Trains a basic LightGBM model with Tune with the function-based API and a LightGBM callback.

:hugging_face: Transformers Example

Contributor: does this actually work? interesting

Contributor Author (amogkam): I got it working for the Sphinx docs but can't figure out how to do it for GitHub rst.

------------------------------------

- `pbt_transformers <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_transformers/pbt_transformers.py>`__: Fine-tunes a :hugging_face: transformer with Tune Population Based Training.


Contributed Examples
--------------------
Empty file.
206 changes: 206 additions & 0 deletions python/ray/tune/examples/pbt_transformers/pbt_transformers.py
@@ -0,0 +1,206 @@
# flake8: noqa
# yapf: disable

import os
import ray
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining

from ray import tune
from ray.tune.examples.pbt_transformers import trainer
from ray.tune.examples.pbt_transformers.utils import build_compute_metrics_fn, download_data

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    Trainer,
    TrainingArguments,
    glue_tasks_num_labels,
)


def get_trainer(model_name_or_path, train_dataset, eval_dataset, task_name, training_args, wandb_args=None):
    try:
        num_labels = glue_tasks_num_labels[task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (task_name))

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        finetuning_task=task_name,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path,
        config=config,
    )
    tune_trainer = trainer.TuneTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(task_name),
        wandb_args=wandb_args
    )

    return tune_trainer


# __train_begin__
def train_transformer(config, checkpoint=None):
    data_args = DataTrainingArguments(
        task_name=config["task_name"],
        data_dir=config["data_dir"]
    )
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train", cache_dir=config["data_dir"])
    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=config["data_dir"])
    eval_dataset = eval_dataset[:len(eval_dataset) // 2]
    training_args = TrainingArguments(
        output_dir=tune.get_trial_dir(),
        learning_rate=config["learning_rate"],
        do_train=True,
        do_eval=True,
        evaluate_during_training=True,
        eval_steps=(len(train_dataset) // config["per_gpu_train_batch_size"]) + 1,
        save_steps=0,  # We explicitly set save here to 0, and do saving in evaluate instead
        num_train_epochs=config["num_epochs"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_gpu_train_batch_size"],
        per_device_eval_batch_size=config["per_gpu_val_batch_size"],
        warmup_steps=0,
        weight_decay=config["weight_decay"],
        logging_dir="./logs",
    )

    # Arguments for W&B.
    name = tune.get_trial_name()
    wandb_args = {
        "project_name": "transformers_pbt",
        "watch": "false",  # Either set to gradient, false, or all
        "run_name": name,
    }

    tune_trainer = get_trainer(config["model_name"], train_dataset, eval_dataset, config["task_name"], training_args,
                               wandb_args=wandb_args)
    tune_trainer.train(checkpoint if checkpoint is not None and len(checkpoint) > 0 else config["model_name"])


# __train_end__


# __tune_begin__
def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):
    # Connect to an existing Ray cluster unless this is a local smoke test.
    ray.init(address="auto" if not smoke_test else None, log_to_driver=False)
    data_dir = os.path.abspath(os.path.join(os.getcwd(), "./data"))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)
    model_name = "bert-base-uncased"

Contributor: for smoke-test can we actually use a super tiny model?

    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")

    # Triggers tokenizer download to cache
    AutoTokenizer.from_pretrained(model_name)
    print("Downloading and caching pre-trained model")

    # Triggers model download to cache
    AutoModelForSequenceClassification.from_pretrained(
        model_name,
    )

    # Download data.
    download_data(task_name, data_dir)

    config = {
        "model_name": model_name,
        "task_name": task_name,
        "data_dir": task_data_dir,
        "per_gpu_val_batch_size": 32,
        "per_gpu_train_batch_size": tune.choice([16, 32, 64]),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "weight_decay": tune.uniform(0.0, 0.3),
        "num_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": -1 if not smoke_test else 3,
    }

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="eval_acc",
        mode="max",
        perturbation_interval=1,
        hyperparam_mutations={
            # tune.uniform() returns a sample_from wrapper; calling its .func
            # draws a fresh sample each time a trial's config is perturbed.
            "weight_decay": lambda: tune.uniform(0.0, 0.3).func(None),
            "learning_rate": lambda: tune.uniform(1e-5, 5e-5).func(None),
            "per_gpu_train_batch_size": [16, 32, 64],
        })

    reporter = CLIReporter(
        parameter_columns={
            "weight_decay": "w_decay",
            "learning_rate": "lr",
            "per_gpu_train_batch_size": "train_bs/gpu",
            "num_epochs": "num_epochs"},
        metric_columns=["eval_acc", "eval_loss", "epoch", "training_iteration"])

    analysis = tune.run(
        train_transformer,
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        keep_checkpoints_num=3,
        checkpoint_score_attr="training_iteration",
        progress_reporter=reporter,
        name="tune_transformer_pbt")

    if not smoke_test:
        test_best_model(analysis, config["model_name"], config["task_name"], config["data_dir"])


# __tune_end__


def test_best_model(analysis, model_name, task_name, data_dir):
    data_args = DataTrainingArguments(
        task_name=task_name,
        data_dir=data_dir
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    best_config = analysis.get_best_config(metric="eval_acc", mode="max")
    print(best_config)
    best_checkpoint = analysis.get_best_trial(metric="eval_acc", mode="max").checkpoint.value
    print(best_checkpoint)
    best_model = AutoModelForSequenceClassification.from_pretrained(best_checkpoint).to("cuda")

    test_args = TrainingArguments(
        output_dir="./best_model_results",
    )
    test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=data_dir)
    test_dataset = test_dataset[len(test_dataset) // 2:]

    test_trainer = Trainer(best_model, test_args, compute_metrics=build_compute_metrics_fn(task_name))

    metrics = test_trainer.evaluate(test_dataset)
    print(metrics)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)
    else:
        # You can change the number of GPUs here:
        tune_transformer(num_samples=8, gpus_per_trial=1)
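
For readers who want to try the example above, a minimal smoke-test invocation from Python is sketched here. This sketch is not part of the diff: it assumes Ray (with this example installed as a package) and transformers are available, and it is equivalent to running the script with --smoke-test.

# Illustrative only; mirrors `python pbt_transformers.py --smoke-test`.
from ray.tune.examples.pbt_transformers.pbt_transformers import tune_transformer

tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)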
75 changes: 75 additions & 0 deletions python/ray/tune/examples/pbt_transformers/trainer.py
@@ -0,0 +1,75 @@
import logging
import os
import torch
from torch.utils.data import Dataset
import transformers
from ray import tune
from typing import Dict, Optional, Tuple

from transformers.file_utils import is_torch_tpu_available
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import wandb

logger = logging.getLogger(__name__)
"""A Trainer class integrated with Tune.
The only changes to the original transformers.Trainer are:
    - Report eval metrics to Tune
    - Save state using Tune's checkpoint directories
"""


class TuneTransformerTrainer(transformers.Trainer):
    def __init__(self, *args, wandb_args=None, **kwargs):
        self.wandb_args = wandb_args
        super().__init__(*args, **kwargs)

    def get_optimizers(
            self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        self.current_optimizer, self.current_scheduler = super(
        ).get_optimizers(num_training_steps)
        return (self.current_optimizer, self.current_scheduler)

    def evaluate(self,
                 eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self._prediction_loop(
            eval_dataloader, description="Evaluation")
        self._log(output.metrics)

        tune.report(**output.metrics)

        self.save_state()

        return output.metrics

    def save_state(self):
        self.args.output_dir = tune.make_checkpoint_dir()
        output_dir = os.path.join(
            self.args.output_dir,
            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
        self.save_model(output_dir)
        if self.is_world_master():
            torch.save(self.current_optimizer.state_dict(),
                       os.path.join(output_dir, "optimizer.pt"))
            torch.save(self.current_scheduler.state_dict(),
                       os.path.join(output_dir, "scheduler.pt"))
            tune.save_checkpoint(output_dir)

    def _setup_wandb(self):
        if self.is_world_master() and self.wandb_args is not None:
            wandb.init(
                project=self.wandb_args["project_name"],
                name=self.wandb_args["run_name"],
                id=self.wandb_args["run_name"],
                config=vars(self.args),
                reinit=True,
                allow_val_change=True,
                resume=self.wandb_args["run_name"])
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available(
            ) and self.wandb_args["watch"] != "false":
                wandb.watch(
                    self.model,
                    log=self.wandb_args["watch"],
                    log_freq=max(100, self.args.logging_steps))
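
For context, the overrides above hook into Tune's function-API checkpointing: evaluate() reports metrics with tune.report(), and save_state() writes the model into a directory from tune.make_checkpoint_dir() and registers it with tune.save_checkpoint(). A stripped-down sketch of that same pattern, using illustrative names (my_trainable, state.txt) that are not part of this PR, looks roughly like:

import os
from ray import tune


def my_trainable(config, checkpoint=None):
    # When PBT exploits another trial, Tune passes back the directory that
    # was previously registered via tune.save_checkpoint().
    step = 0
    if checkpoint:
        with open(os.path.join(checkpoint, "state.txt")) as f:
            step = int(f.read())
    while step < 10:
        step += 1
        # ... one unit of training work goes here ...
        checkpoint_dir = tune.make_checkpoint_dir()
        with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
            f.write(str(step))
        tune.save_checkpoint(checkpoint_dir)
        tune.report(eval_acc=step / 10)  # same metric name the PBT scheduler optimizes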
43 changes: 43 additions & 0 deletions python/ray/tune/examples/pbt_transformers/utils.py
@@ -0,0 +1,43 @@
"""Utilities to load and cache data."""

import os
from typing import Callable, Dict
import numpy as np
from transformers import EvalPrediction
from transformers import glue_compute_metrics, glue_output_modes
"""From transformers/examples/text-classification/run_glue.py"""


def build_compute_metrics_fn(
task_name: str) -> Callable[[EvalPrediction], Dict]:
output_mode = glue_output_modes[task_name]

def compute_metrics_fn(p: EvalPrediction):
if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1)
elif output_mode == "regression":
preds = np.squeeze(p.predictions)
metrics = glue_compute_metrics(task_name, preds, p.label_ids)
return metrics

return compute_metrics_fn


def download_data(task_name, data_dir="./data"):
# Download RTE training data
print("Downloading dataset.")
import urllib
import zipfile
if task_name == "rte":
url = "https://firebasestorage.googleapis.com/v0/b/" \
"mtl-sentence-representations.appspot.com" \
"/o/data%2FRTE.zip?alt=media" \
"&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb"
else:
raise ValueError("Unknown task: {}".format(task_name))
data_file = os.path.join(data_dir, "{}.zip".format(task_name))
if not os.path.exists(data_file):
urllib.request.urlretrieve(url, data_file)
with zipfile.ZipFile(data_file) as zip_ref:
zip_ref.extractall(data_dir)
print("Downloaded data for task {} to {}".format(task_name, data_dir))