[Tune] Transformer blog example (#9789)
Co-authored-by: Kai Fricke <kai@anyscale.com>
Showing 15 changed files with 443 additions and 0 deletions.
@@ -0,0 +1,8 @@
:orphan:

pbt_transformers_example
~~~~~~~~~~~~~~~~~~~~~~~~

.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/pbt_transformers.py
.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/trainer.py
.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/utils.py
Empty file.
262 changes: 262 additions & 0 deletions
python/ray/tune/examples/pbt_transformers/pbt_transformers.py
@@ -0,0 +1,262 @@
import os

import ray
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining

from ray import tune
from ray.tune.examples.pbt_transformers.utils import \
    build_compute_metrics_fn, download_data
from ray.tune.examples.pbt_transformers import trainer

from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, GlueDataset, GlueDataTrainingArguments
                          as DataTrainingArguments, glue_tasks_num_labels,
                          Trainer, TrainingArguments)

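# Build the model config, model, and TuneTransformerTrainer (defined in
# trainer.py of this example) for the given GLUE task and datasets.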
def get_trainer(model_name_or_path,
                train_dataset,
                eval_dataset,
                task_name,
                training_args,
                wandb_args=None):
    try:
        num_labels = glue_tasks_num_labels[task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (task_name))

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        finetuning_task=task_name,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path,
        config=config,
    )
    tune_trainer = trainer.TuneTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(task_name),
        wandb_args=wandb_args)

    return tune_trainer

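# Resolve a Tune checkpoint directory to the Hugging Face checkpoint
# subdirectory inside it; fall back to `model_name` if no checkpoint is given.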
def recover_checkpoint(tune_checkpoint_dir, model_name=None):
    if tune_checkpoint_dir is None or len(tune_checkpoint_dir) == 0:
        return model_name
    # Get subdirectory used for Huggingface.
    subdirs = [
        os.path.join(tune_checkpoint_dir, name)
        for name in os.listdir(tune_checkpoint_dir)
        if os.path.isdir(os.path.join(tune_checkpoint_dir, name))
    ]
    # There should only be 1 subdir.
    assert len(subdirs) == 1, subdirs
    return subdirs[0]


# __train_begin__
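# Tune trainable: each trial fine-tunes the model on the GLUE task with the
# hyperparameters in `config`, resuming from `checkpoint_dir` when one is
# passed in (e.g. after a PBT exploit step).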
def train_transformer(config, checkpoint_dir=None):
    data_args = DataTrainingArguments(
        task_name=config["task_name"], data_dir=config["data_dir"])
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    train_dataset = GlueDataset(
        data_args,
        tokenizer=tokenizer,
        mode="train",
        cache_dir=config["data_dir"])
    eval_dataset = GlueDataset(
        data_args,
        tokenizer=tokenizer,
        mode="dev",
        cache_dir=config["data_dir"])
    eval_dataset = eval_dataset[:len(eval_dataset) // 2]
    training_args = TrainingArguments(
        output_dir=tune.get_trial_dir(),
        learning_rate=config["learning_rate"],
        do_train=True,
        do_eval=True,
        evaluate_during_training=True,
        eval_steps=(len(train_dataset) // config["per_gpu_train_batch_size"]) +
        1,
        # We explicitly set save to 0, and do saving in evaluate instead
        save_steps=0,
        num_train_epochs=config["num_epochs"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_gpu_train_batch_size"],
        per_device_eval_batch_size=config["per_gpu_val_batch_size"],
        warmup_steps=0,
        weight_decay=config["weight_decay"],
        logging_dir="./logs",
    )

    # Arguments for W&B.
    name = tune.get_trial_name()
    wandb_args = {
        "project_name": "transformers_pbt",
        "watch": "false",  # Either set to gradient, false, or all
        "run_name": name,
    }

    tune_trainer = get_trainer(
        recover_checkpoint(checkpoint_dir, config["model_name"]),
        train_dataset,
        eval_dataset,
        config["task_name"],
        training_args,
        wandb_args=wandb_args)
    tune_trainer.train(
        recover_checkpoint(checkpoint_dir, config["model_name"]))


# __train_end__


# __tune_begin__
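# Driver: download the tokenizer, model, and data, define the search space,
# and run Population Based Training over `num_samples` trials.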
def tune_transformer(num_samples=8,
                     gpus_per_trial=0,
                     smoke_test=False,
                     ray_address=None):
    ray.init(ray_address, log_to_driver=False)
    data_dir_name = "./data" if not smoke_test else "./test_data"
    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)

    # Change these as needed.
    model_name = "bert-base-uncased" if not smoke_test \
        else "distilbert-base-uncased"
    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")

    # Triggers tokenizer download to cache
    AutoTokenizer.from_pretrained(model_name)
    print("Downloading and caching pre-trained model")

    # Triggers model download to cache
    AutoModelForSequenceClassification.from_pretrained(model_name)

    # Download data.
    download_data(task_name, data_dir)

    config = {
        "model_name": model_name,
        "task_name": task_name,
        "data_dir": task_data_dir,
        "per_gpu_val_batch_size": 32,
        "per_gpu_train_batch_size": tune.choice([16, 32, 64]),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "weight_decay": tune.uniform(0.0, 0.3),
        "num_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": 1 if smoke_test else -1,  # Used for smoke test.
    }

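    # PBT resamples or perturbs these hyperparameters every
    # perturbation_interval. The lambdas draw a fresh value from the same
    # distributions used in the search space above; `.func(None)` calls the
    # underlying sampling function directly.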
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="eval_acc",
        mode="max",
        perturbation_interval=1,
        hyperparam_mutations={
            "weight_decay": lambda: tune.uniform(0.0, 0.3).func(None),
            "learning_rate": lambda: tune.uniform(1e-5, 5e-5).func(None),
            "per_gpu_train_batch_size": [16, 32, 64],
        })

    reporter = CLIReporter(
        parameter_columns={
            "weight_decay": "w_decay",
            "learning_rate": "lr",
            "per_gpu_train_batch_size": "train_bs/gpu",
            "num_epochs": "num_epochs"
        },
        metric_columns=[
            "eval_acc", "eval_loss", "epoch", "training_iteration"
        ])

    analysis = tune.run(
        train_transformer,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        keep_checkpoints_num=3,
        checkpoint_score_attr="training_iteration",
        stop={"training_iteration": 1} if smoke_test else None,
        progress_reporter=reporter,
        local_dir="~/ray_results/",
        name="tune_transformer_pbt")

    if not smoke_test:
        test_best_model(analysis, config["model_name"], config["task_name"],
                        config["data_dir"])


# __tune_end__

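# Evaluate the best-performing checkpoint on the held-out half of the dev set
# (the half not used for evaluation during training).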
def test_best_model(analysis, model_name, task_name, data_dir):
    data_args = DataTrainingArguments(task_name=task_name, data_dir=data_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    best_config = analysis.get_best_config(metric="eval_acc", mode="max")
    print(best_config)
    best_checkpoint = recover_checkpoint(
        analysis.get_best_trial(metric="eval_acc",
                                mode="max").checkpoint.value)
    print(best_checkpoint)
    best_model = AutoModelForSequenceClassification.from_pretrained(
        best_checkpoint).to("cuda")

    test_args = TrainingArguments(output_dir="./best_model_results", )
    test_dataset = GlueDataset(
        data_args, tokenizer=tokenizer, mode="dev", cache_dir=data_dir)
    test_dataset = test_dataset[len(test_dataset) // 2:]

    test_trainer = Trainer(
        best_model,
        test_args,
        compute_metrics=build_compute_metrics_fn(task_name))

    metrics = test_trainer.evaluate(test_dataset)
    print(metrics)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--ray-address",
        type=str,
        default=None,
        help="Address to use for Ray. "
        "Use \"auto\" for cluster. "
        "Defaults to None for local.")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        tune_transformer(
            num_samples=1,
            gpus_per_trial=0,
            smoke_test=True,
            ray_address=args.ray_address)
    else:
        # You can change the number of GPUs here:
        tune_transformer(
            num_samples=8, gpus_per_trial=1, ray_address=args.ray_address)
10 changes: 10 additions & 0 deletions
python/ray/tune/examples/pbt_transformers/test_data/RTE/dev.tsv
@@ -0,0 +1,10 @@
index sentence1 sentence2 label
0 Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation. Christopher Reeve had an accident. not_entailment
1 Yet, we now are discovering that antibiotics are losing their effectiveness against illness. Disease-causing bacteria are mutating faster than we can come up with new antibiotics to fight the new variations. Bacteria is winning the war against antibiotics. entailment
2 Cairo is now home to some 15 million people - a burgeoning population that produces approximately 10,000 tonnes of rubbish per day, putting an enormous strain on public services. In the past 10 years, the government has tried hard to encourage private investment in the refuse sector, but some estimate 4,000 tonnes of waste is left behind every day, festering in the heat as it waits for someone to clear it up. It is often the people in the poorest neighbourhoods that are worst affected. But in some areas they are fighting back. In Shubra, one of the northern districts of the city, the residents have taken to the streets armed with dustpans and brushes to clean up public areas which have been used as public dumps. 15 million tonnes of rubbish are produced daily in Cairo. not_entailment
3 The Amish community in Pennsylvania, which numbers about 55,000, lives an agrarian lifestyle, shunning technological advances like electricity and automobiles. And many say their insular lifestyle gives them a sense that they are protected from the violence of American society. But as residents gathered near the school, some wearing traditional garb and arriving in horse-drawn buggies, they said that sense of safety had been shattered. "If someone snaps and wants to do something stupid, there's no distance that's going to stop them," said Jake King, 56, an Amish lantern maker who knew several families whose children had been shot. Pennsylvania has the biggest Amish community in the U.S. not_entailment
4 Security forces were on high alert after an election campaign in which more than 1,000 people, including seven election candidates, have been killed. Security forces were on high alert after a campaign marred by violence. entailment
5 In 1979, the leaders signed the Egypt-Israel peace treaty on the White House lawn. Both President Begin and Sadat received the Nobel Peace Prize for their work. The two nations have enjoyed peaceful relations to this day. The Israel-Egypt Peace Agreement was signed in 1979. entailment
6 singer and actress Britney Spears, 24, has filled papers in Los Angeles County Superior Court to divorce her husband Kevin Federline, 28. A spokeswoman for the court, Kathy Roberts stated that the papers cited irreconcilable differences" as the reason for the divorce and have, according to the courts, been legally separated as of Monday, November 6, the same day that Spears appeared on Late Night with David Letterman. Spears is to divorce from Kevin Federline. entailment
7 Following the successful bid to bring the 2010 Ryder Cup to Wales, the Wales Tourist Board has wasted little time in commissioning work to ensure that the benefits accruing from the event are felt throughout the country. Wales to host 2010 Ryder Cup. entailment
8 Steve Jobs was attacked by Sculley and other Apple executives for not delivering enough hot new products and resigned from the company a few weeks later. Steve Jobs worked for Apple. entailment
10 changes: 10 additions & 0 deletions
python/ray/tune/examples/pbt_transformers/test_data/RTE/train.tsv
@@ -0,0 +1,10 @@
index sentence1 sentence2 label
0 No Weapons of Mass Destruction Found in Iraq Yet. Weapons of Mass Destruction Found in Iraq. not_entailment
1 A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI. Pope Benedict XVI is the new leader of the Roman Catholic Church. entailment
2 Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients. Herceptin can be used to treat breast cancer. entailment
3 Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment. The previous name of Ho Chi Minh City was Saigon. entailment
4 A man is due in court later charged with the murder 26 years ago of a teenager whose case was the first to be featured on BBC One's Crimewatch. Colette Aram, 16, was walking to her boyfriend's house in Keyworth, Nottinghamshire, on 30 October 1983 when she disappeared. Her body was later found in a field close to her home. Paul Stewart Hutchinson, 50, has been charged with murder and is due before Nottingham magistrates later. Paul Stewart Hutchinson is accused of having stabbed a girl. not_entailment
5 Britain said, Friday, that it has barred cleric, Omar Bakri, from returning to the country from Lebanon, where he was released by police after being detained for 24 hours. Bakri was briefly detained, but was released. entailment
6 Nearly 4 million children who have at least one parent who entered the U.S. illegally were born in the United States and are U.S. citizens as a result, according to the study conducted by the Pew Hispanic Center. That's about three quarters of the estimated 5.5 million children of illegal immigrants inside the United States, according to the study. About 1.8 million children of undocumented immigrants live in poverty, the study found. Three quarters of U.S. illegal immigrants have children. not_entailment
7 Like the United States, U.N. officials are also dismayed that Aristide killed a conference called by Prime Minister Robert Malval in Port-au-Prince in hopes of bringing all the feuding parties together. Aristide had Prime Minister Robert Malval murdered in Port-au-Prince. not_entailment
8 WASHINGTON -- A newly declassified narrative of the Bush administration's advice to the CIA on harsh interrogations shows that the small group of Justice Department lawyers who wrote memos authorizing controversial interrogation techniques were operating not on their own but with direction from top administration officials, including then-Vice President Dick Cheney and national security adviser Condoleezza Rice. At the same time, the narrative suggests that then-Defense Secretary Donald H. Rumsfeld and then-Secretary of State Colin Powell were largely left out of the decision-making process. Dick Cheney was the Vice President of Bush. entailment
Empty file.