Skip to content

Commit

Permalink
Merge pull request #3 from ThilinaRajapakse/master
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
kinoute authored Jan 8, 2020
2 parents d34e2fe + f5a7699 commit e5987c6
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 60 deletions.
54 changes: 53 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,42 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.16.2] - 2020-01-08
### Changed
- Changed print statements to logging.

## [0.16.1] - 2020-01-07
### Added
- Added `wandb_kwargs` to `args` which can be used to specify keyword arguments to `wandb.init()` method.

## [0.16.0] - 2020-01-07
### Added
- Added support for training visualization using the W&B framework.
- Added `save_eval_checkpoints` attribute to `args` which controls whether or not a model checkpoint will be saved with every evaluation.

## [0.15.7] - 2020-01-05
### Added
- Added `**kwargs` for different accuracy measures during multilabel training.

## [0.15.6] - 2020-01-05
### Added
- Added `train_loss` to `training_progress_scores.csv` (which contains the evaluation results of all checkpoints) in the output directory.

## [0.15.5] - 2020-01-05
### Added
- Using `evaluate_during_training` now generates `training_progress_scores.csv` (which contains the evaluation results of all checkpoints) in the output directory.

## [0.15.4] - 2019-12-31
### Fixed
- Fixed bug in `QuestonAnsweringModel` when using `evaluate_during_training`.

## [0.15.3] - 2019-12-31
### Fixed
- Fixed bug in MultiLabelClassificationModel due to `tensorboard_dir` being missing in parameter dictionary.

### Changed
- Renamed `tensorboard_folder` to `tensorboard_dir` for consistency.

## [0.15.2] - 2019-12-28
### Added
- Added `tensorboard_folder` to parameter dictionary which can be used to specify the directory in which the tensorboard files will be stored.
Expand Down Expand Up @@ -82,7 +118,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- This CHANGELOG file to hopefully serve as an evolving example of a
standardized open source project CHANGELOG.

[0.15.1]: https://github.com/ThilinaRajapakse/simpletransformers/compare/268ced8...HEAD
[0.16.2]: https://github.com/ThilinaRajapakse/simpletransformers/compare/d589b75...HEAD

[0.16.1]: https://github.com/ThilinaRajapakse/simpletransformers/compare/d8df83f...d589b75

[0.16.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/1684fff...d8df83f

[0.15.7]: https://github.com/ThilinaRajapakse/simpletransformers/compare/c2f620a...1684fff

[0.15.6]: https://github.com/ThilinaRajapakse/simpletransformers/compare/cd24331...c2f620a

[0.15.5]: https://github.com/ThilinaRajapakse/simpletransformers/compare/38cbea5...cd24331

[0.15.4]: https://github.com/ThilinaRajapakse/simpletransformers/compare/70e2a19...38cbea5

[0.15.3]: https://github.com/ThilinaRajapakse/simpletransformers/compare/a65dc73...70e2a19

[0.15.2]: https://github.com/ThilinaRajapakse/simpletransformers/compare/268ced8...a65dc73

[0.15.1]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2c1e5e0...268ced8

Expand Down
39 changes: 33 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ This library is based on the [Transformers](https://github.com/huggingface/trans
* [Minimal Start](#minimal-example)
* [Real Dataset Examples](#real-dataset-examples-2)
* [QuestionAnsweringModel](#questionansweringmodel)
* [Visualization Support](#visualization-support)
* [Experimental Features](#experimental-features)
* [Sliding Window For Long Sequences](#sliding-window-for-long-sequences)
* [Loading Saved Models](#loading-saved-models)
Expand Down Expand Up @@ -208,8 +209,8 @@ print(raw_outputs)
#### Real Dataset Examples

* [Yelp Reviews Dataset - Binary Classification](https://towardsdatascience.com/simple-transformers-introducing-the-easiest-bert-roberta-xlnet-and-xlm-library-58bf8c59b2a3?source=friends_link&sk=40726ceeadf99e1120abc9521a10a55c)
* [AG News Dataset - Multiclass Classification](https://medium.com/swlh/simple-transformers-multi-class-text-classification-with-bert-roberta-xlnet-xlm-and-8b585000ce3a)
* [Toxic Comments Dataset - Multilabel Classification](https://medium.com/@chaturangarajapakshe/multi-label-classification-using-bert-roberta-xlnet-xlm-and-distilbert-with-simple-transformers-b3e0cda12ce5?sk=354e688fe238bfb43e9a575216816219)
* [AG News Dataset - Multiclass Classification](https://medium.com/swlh/simple-transformers-multi-class-text-classification-with-bert-roberta-xlnet-xlm-and-8b585000ce3a?source=friends_link&sk=90e1c97255b65cedf4910a99041d9dfc)
* [Toxic Comments Dataset - Multilabel Classification](https://towardsdatascience.com/multi-label-classification-using-bert-roberta-xlnet-xlm-and-distilbert-with-simple-transformers-b3e0cda12ce5?source=friends_link&sk=354e688fe238bfb43e9a575216816219)


#### ClassificationModel
Expand Down Expand Up @@ -368,7 +369,7 @@ print(predictions)

#### Real Dataset Examples

* [CoNLL Dataset Example](https://medium.com/@chaturangarajapakshe/simple-transformers-named-entity-recognition-with-transformer-models-c04b9242a2a0?sk=e8b98c994173cd5219f01e075727b096)
* [CoNLL Dataset Example](https://towardsdatascience.com/simple-transformers-named-entity-recognition-with-transformer-models-c04b9242a2a0?source=friends_link&sk=e8b98c994173cd5219f01e075727b096)

#### NERModel

Expand Down Expand Up @@ -577,7 +578,7 @@ print(model.predict(to_predict))

#### Real Dataset Examples

* [SQuAD 2.0 - Question Answering](https://medium.com/@chaturangarajapakshe/question-answering-with-bert-xlnet-xlm-and-distilbert-using-simple-transformers-4d8785ee762a?sk=e8e6f9a39f20b5aaf08bbcf8b0a0e1c2)
* [SQuAD 2.0 - Question Answering](https://towardsdatascience.com/question-answering-with-bert-xlnet-xlm-and-distilbert-using-simple-transformers-4d8785ee762a?source=friends_link&sk=e8e6f9a39f20b5aaf08bbcf8b0a0e1c2)

### QuestionAnsweringModel

Expand Down Expand Up @@ -702,6 +703,26 @@ If null_score - best_non_null is greater than the threshold predict null.

---

## Visualization Support

The [Weights & Biases](https://www.wandb.com/) framework is supported for visualizing model training.

To use this, simply set a project name for W&B in the `wandb_project` attribute of the `args` dictionary. This will log all hyperparameter values, training losses, and evaluation metrics to the given project.

```
model = ClassificationModel('roberta', 'roberta-base', args={'wandb_project': 'project-name'})
```

Other keyword arguments can be specified as a dictionay with the `wandb_kwargs` attribute of the `args` dictionary.

```
model = ClassificationModel('roberta', 'roberta-base', args={'wandb_project': 'project-name', 'wandb_kwargs': {'name': 'test-run'}})
```

For a complete example, see [here](https://medium.com/skilai/to-see-is-to-believe-visualizing-the-training-of-machine-learning-models-664ef3fe4f49).

---

## Experimental Features

To use experimental features, import from `simpletransformers.experimental.X`
Expand Down Expand Up @@ -806,8 +827,9 @@ self.args = {
'logging_steps': 50,
'evaluate_during_training': False,
'evaluate_during_training_steps': 2000,
`save_eval_checkpoints`: True
'save_steps': 2000,
'tensorboard_folder': None,
'tensorboard_dir': None,
'overwrite_output_dir': False,
'reprocess_input_data': False,
Expand All @@ -816,6 +838,8 @@ self.args = {
'n_gpu': 1,
'silent': False,
'use_multiprocessing': True,
'wandb_project': None,
}
```

Expand Down Expand Up @@ -869,13 +893,16 @@ Set to True to perform evaluation while training models. Make sure `eval_df` is
#### *evaluate_during_training_steps*
Perform evaluation at every specified number of steps. A checkpoint model and the evaluation results will be saved.

#### *save_eval_checkpoints*
Save a model checkpoint for every evaluation performed.

#### *logging_steps: int*
Log training loss and learning at every specified number of steps.

#### *save_steps: int*
Save a model checkpoint at every specified number of steps.

#### *tensorboard_folder: str*
#### *tensorboard_dir: str*
The directory where Tensorboard events will be stored during training. By default, Tensorboard events will be saved in a subfolder inside `runs/` like `runs/Dec02_09-32-58_36d9e58955b0/`.

#### *overwrite_output_dir: bool*
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="simpletransformers",
version="0.15.2",
version="0.16.2",
author="Thilina Rajapakse",
author_email="chaturangarajapakshe@gmail.com",
description="An easy-to-use wrapper library for the Transformers library.",
Expand All @@ -30,6 +30,7 @@
"scipy",
"scikit-learn",
"seqeval",
"tensorboardx"
"tensorboardx",
"wandb"
],
)
91 changes: 78 additions & 13 deletions simpletransformers/classification/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import os
import math
import json
import logging
import random
import warnings

from multiprocessing import cpu_count

import torch
import numpy as np
import pandas as pd

from scipy.stats import pearsonr, mode
from sklearn.metrics import mean_squared_error, matthews_corrcoef, confusion_matrix, label_ranking_average_precision_score
Expand Down Expand Up @@ -53,6 +55,11 @@
from simpletransformers.classification.transformer_models.albert_model import AlbertForSequenceClassification
from simpletransformers.classification.transformer_models.camembert_model import CamembertForSequenceClassification

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

import wandb


class ClassificationModel:
def __init__(self, model_type, model_name, num_labels=None, weight=None, args=None, use_cuda=True, cuda_device=-1):
Expand Down Expand Up @@ -129,7 +136,8 @@ def __init__(self, model_type, model_name, num_labels=None, weight=None, args=No
'save_steps': 2000,
'evaluate_during_training': False,
'evaluate_during_training_steps': 2000,
'tensorboard_folder': None,
'save_eval_checkpoints': True,
'tensorboard_dir': None,

'overwrite_output_dir': False,
'reprocess_input_data': False,
Expand All @@ -141,7 +149,10 @@ def __init__(self, model_type, model_name, num_labels=None, weight=None, args=No

'sliding_window': False,
'tie_value': 1,
'stride': 0.8
'stride': 0.8,

'wandb_project': None,
'wandb_kwargs': None,
}

if not use_cuda:
Expand All @@ -159,6 +170,7 @@ def __init__(self, model_type, model_name, num_labels=None, weight=None, args=No
warnings.warn("use_multiprocessing automatically disabled as CamemBERT fails when using multiprocessing for feature conversion.")
self.args['use_multiprocessing'] = False


def train_model(self, train_df, multi_label=False, output_dir=None, show_running_loss=True, args=None, eval_df=None, **kwargs):
"""
Trains the model using 'train_df'
Expand Down Expand Up @@ -204,16 +216,16 @@ def train_model(self, train_df, multi_label=False, output_dir=None, show_running
if not os.path.exists(output_dir):
os.makedirs(output_dir)

global_step, tr_loss = self.train(train_dataset, output_dir, show_running_loss=show_running_loss, eval_df=eval_df, **kwargs)
global_step, tr_loss = self.train(train_dataset, output_dir, multi_label=multi_label, show_running_loss=show_running_loss, eval_df=eval_df, **kwargs)

model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

print("Training of {} model complete. Saved to {}.".format(self.args["model_type"], output_dir))
logger.info("Training of {} model complete. Saved to {}.".format(self.args["model_type"], output_dir))

def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None, **kwargs):
def train(self, train_dataset, output_dir, multi_label=False, show_running_loss=True, eval_df=None, **kwargs):
"""
Trains the model on train_dataset.
Expand All @@ -225,7 +237,7 @@ def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None,
model = self.model
args = self.args

tb_writer = SummaryWriter(logdir=args["tensorboard_folder"])
tb_writer = SummaryWriter(logdir=args["tensorboard_dir"])
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args["train_batch_size"])

Expand Down Expand Up @@ -262,6 +274,41 @@ def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None,
model.zero_grad()
train_iterator = trange(int(args["num_train_epochs"]), desc="Epoch", disable=args['silent'])
epoch_number = 0
if args['evaluate_during_training']:
extra_metrics = {key: [] for key in kwargs}
if multi_label:
training_progress_scores = {
'global_step': [],
'LRAP': [],
'train_loss': [],
'eval_loss': [],
**extra_metrics
}
else:
if self.model.num_labels == 2:
training_progress_scores = {
'global_step': [],
'tp': [],
'tn': [],
'fp': [],
'fn': [],
'mcc': [],
'train_loss': [],
'eval_loss': [],
**extra_metrics
}
else:
training_progress_scores = {
'global_step': [],
'mcc': [],
'train_loss': [],
'eval_loss': [],
**extra_metrics
}

if args['wandb_project']:
wandb.init(project=args['wandb_project'], config={**args}, **args['wandb_kwargs'])
wandb.watch(self.model)

model.train()
for _ in train_iterator:
Expand All @@ -277,8 +324,10 @@ def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None,
if args['n_gpu'] > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

current_loss = loss.item()

if show_running_loss:
print("\rRunning loss: %f" % loss, end="")
logger.info("\rRunning loss: %f" % loss, end="")

if args["gradient_accumulation_steps"] > 1:
loss = loss / args["gradient_accumulation_steps"]
Expand All @@ -303,6 +352,8 @@ def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None,
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
tb_writer.add_scalar("loss", (tr_loss - logging_loss)/args["logging_steps"], global_step)
logging_loss = tr_loss
if args['wandb_project']:
wandb.log({'Training loss': current_loss, 'lr': scheduler.get_lr()[0], 'global_step': global_step})

if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
# Save model checkpoint
Expand All @@ -327,15 +378,26 @@ def train(self, train_dataset, output_dir, show_running_loss=True, eval_df=None,
if not os.path.exists(output_dir_current):
os.makedirs(output_dir_current)

model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir_current)
self.tokenizer.save_pretrained(output_dir_current)
if args['save_eval_checkpoints']:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir_current)
self.tokenizer.save_pretrained(output_dir_current)

output_eval_file = os.path.join(output_dir_current, "eval_results.txt")
with open(output_eval_file, "w") as writer:
for key in sorted(results.keys()):
writer.write("{} = {}\n".format(key, str(results[key])))

training_progress_scores['global_step'].append(global_step)
training_progress_scores['train_loss'].append(current_loss)
for key in results:
training_progress_scores[key].append(results[key])
report = pd.DataFrame(training_progress_scores)
report.to_csv(args['output_dir'] + 'training_progress_scores.csv', index=False)

if args['wandb_project']:
wandb.log(self._get_last_metrics(training_progress_scores))

epoch_number += 1
output_dir_current = os.path.join(output_dir, "epoch-{}".format(epoch_number))

Expand Down Expand Up @@ -383,7 +445,7 @@ def eval_model(self, eval_df, multi_label=False, output_dir=None, verbose=False,
self.results.update(result)

if verbose:
print(self.results)
logger.info(self.results)

return result, model_outputs, wrong_preds

Expand Down Expand Up @@ -507,9 +569,9 @@ def load_and_cache_examples(self, examples, evaluate=False, no_cache=False, mult

if os.path.exists(cached_features_file) and not args["reprocess_input_data"] and not no_cache:
features = torch.load(cached_features_file)
print(f"Features loaded from cache at {cached_features_file}")
logger.info(f"Features loaded from cache at {cached_features_file}")
else:
print(f"Converting to features started.")
logger.info(f"Converting to features started.")
features = convert_examples_to_features(
examples,
args["max_seq_length"],
Expand Down Expand Up @@ -717,3 +779,6 @@ def _get_inputs_dict(self, batch):
inputs["token_type_ids"] = batch[2] if self.args["model_type"] in ["bert", "xlnet"] else None

return inputs

def _get_last_metrics(self, metric_values):
return {metric: values[-1] for metric, values in metric_values.items()}
Loading

0 comments on commit e5987c6

Please sign in to comment.