Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Extend BERT-based classification with customized layers #4553

Merged
merged 5 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ commands:
- run:
name: Install torch GPU and dependencies
command: |
python -m pip install --progress-bar off torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
python -m pip install --progress-bar off torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
python -m pip install --progress-bar off 'fairscale~=0.4.0'
python -m pip install --progress-bar off pytorch-pretrained-bert
python -m pip install --progress-bar off 'transformers==4.3.3'
Expand All @@ -124,7 +124,7 @@ commands:
name: Install torch CPU and dependencies
command: |
python -m pip install --progress-bar off 'transformers==4.3.3'
python -m pip install --progress-bar off 'torch==1.10.2'
python -m pip install --progress-bar off 'torch==1.11.0'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env

Expand All @@ -134,7 +134,7 @@ commands:
- run:
name: Install torch CPU and dependencies
command: |
python -m pip install --progress-bar off 'torch==1.10.2+cpu' 'torchvision==0.11.3+cpu' 'torchaudio==0.10.2+cpu' -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --progress-bar off 'torch==1.11.0+cpu' 'torchvision==0.12.0+cpu' 'torchaudio==0.11.0+cpu' -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --progress-bar off 'transformers==4.3.3'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
Expand Down
91 changes: 76 additions & 15 deletions parlai/agents/bert_classifier/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

"""
BERT classifier agent uses bert embeddings to make an utterance-level classification.

This implementation allows to customize classifier layers with input arguments.
"""

import os
Expand All @@ -32,6 +34,12 @@
)


LINEAR = "linear"
RELU = "relu"

SUPPORTED_LAYERS = [LINEAR, RELU]


class BertClassifierHistory(History):
"""
Handles tokenization history.
Expand Down Expand Up @@ -72,6 +80,7 @@ def __init__(self, opt, shared=None):
opt["pretrained_path"] = self.pretrained_path
self.add_cls_token = opt.get("add_cls_token", True)
self.sep_last_utt = opt.get("sep_last_utt", False)
self.classifier_layers = opt.get("classifier_layers", None)
super().__init__(opt, shared)

@classmethod
Expand All @@ -90,20 +99,6 @@ def add_cmdline_args(
"""
super().add_cmdline_args(parser, partial_opt=partial_opt)
parser = parser.add_argument_group("BERT Classifier Arguments")
parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this option just never used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I actually also found it in BertWrapper's add_common_args function, also never used, so I'll remove it from there

"--type-optimization",
type=str,
default="all_encoder_layers",
choices=[
"additional_layers",
"top_layer",
"top4_layers",
"all_encoder_layers",
"all",
],
help="which part of the encoders do we optimize "
"(defaults to all layers)",
)
parser.add_argument(
"--add-cls-token",
type="bool",
Expand All @@ -117,6 +112,13 @@ def add_cmdline_args(
help="separate the last utterance into a different"
"segment with [SEP] token in between",
)
parser.add_argument(
"--classifier-layers",
nargs='+',
type=str,
default=None,
help="list of classifier layers comma-separated with layer's dimension where applicable. For example: linear,64 linear,32 relu",
)
parser.set_defaults(dict_maxexs=0) # skip building dictionary
return parser

Expand All @@ -142,12 +144,71 @@ def upgrade_opt(cls, opt_on_disk):

return opt_on_disk

def _get_layer_parameters(self, prev_dimension, output_dimension):
"""
Parse layer definitions from the input.
"""
layers = []
dimensions = []
no_dimension = -1
for layer in self.classifier_layers:
if ',' in layer:
l, d = layer.split(',')
layers.append(l)
dimensions.append((prev_dimension, int(d)))
prev_dimension = int(d)
else:
layers.append(layer)
dimensions.append(no_dimension)
ind = 0
while (
ind < len(dimensions)
and dimensions[len(dimensions) - ind - 1] == no_dimension
):
ind += 1
if (ind == len(dimensions) and prev_dimension == output_dimension) or (
ind < len(dimensions) and dimensions[ind][1] == output_dimension
):
return layers, dimensions

if ind < len(dimensions):
raise Exception(
f"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}"
)
raise Exception(
f"Output layer's dimension does not match number of classes. Found {prev_dimension}, expected {output_dimension}"
)

def _map_layer(self, layer: str, dim=None):
"""
Get torch wrappers for nn layers.
"""
if layer == LINEAR:
return torch.nn.Linear(dim[0], dim[1])
elif layer == RELU:
return torch.nn.ReLU(inplace=False)
raise Exception(
"Unrecognized network layer {}. Available options are: {}".format(
layer, ", ".join(SUPPORTED_LAYERS)
)
)

def build_model(self):
"""
Construct the model.
"""
num_classes = len(self.class_list)
return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
bert_model = BertModel.from_pretrained(self.pretrained_path)
if self.classifier_layers is not None:
prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1)
layers, dims = self._get_layer_parameters(
prev_dimension=prev_dimension, output_dimension=num_classes
)
decoders = torch.nn.Sequential()
for l, d in zip(layers, dims):
decoders.append(self._map_layer(l, d))
return BertWrapper(bert_model=bert_model, classifier_layer=decoders)
return BertWrapper(bert_model=bert_model, output_dim=num_classes)

def _set_text_vec(self, *args, **kwargs):
obs = super()._set_text_vec(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions parlai/agents/bert_ranker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ In order to use those agents you need to install pytorch-pretrained-bert (https:

Train a BiEncoder BERT model on ConvAI2:
```bash
parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True
parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True
```

Train a CrossEncoder BERT model on ConvAI2:
```bash
parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True
parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True
```
51 changes: 29 additions & 22 deletions parlai/agents/bert_ranker/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
'installed. Install with:\n `pip install transformers`.'
)


import torch


Expand Down Expand Up @@ -63,20 +62,6 @@ def add_common_args(parser):
'multiple gpus. NOTE This is incompatible'
' with distributed training',
)
parser.add_argument(
'--type-optimization',
type=str,
default='all_encoder_layers',
choices=[
'additional_layers',
'top_layer',
'top4_layers',
'all_encoder_layers',
'all',
],
help='Which part of the encoders do we optimize. '
'(Default: all_encoder_layers.)',
)
parser.add_argument(
'--bert-aggregation',
type=str,
Expand All @@ -97,15 +82,26 @@ def add_common_args(parser):
class BertWrapper(torch.nn.Module):
"""
Adds a optional transformer layer and a linear layer on top of BERT.
Args:
bert_model: pretrained BERT model
output_dim: dimension of the output layer for defult 1 linear layer classifier. Either output_dim or classifier_layer must be specified
classifier_layer: classification layers, can be a signle layer, or list of layers (for ex, ModuleList)
add_transformer_layer: if additional transformer layer should be added on top of the pretrained model
layer_pulled: which layer should be pulled from pretrained model
aggregation: embeddings aggregation (pooling) strategy. Available options are:
(default)"first" - [CLS] representation,
"mean" - average of all embeddings except CLS,
"max" - max of all embeddings except CLS
"""

def __init__(
self,
bert_model,
output_dim,
add_transformer_layer=False,
layer_pulled=-1,
aggregation="first",
bert_model: BertModel,
output_dim: int = -1,
add_transformer_layer: bool = False,
layer_pulled: int = -1,
aggregation: str = "first",
classifier_layer: torch.nn.Module = None,
):
super(BertWrapper, self).__init__()
self.layer_pulled = layer_pulled
Expand All @@ -123,7 +119,18 @@ def __init__(
hidden_act='gelu',
)
self.additional_transformer_layer = BertLayer(config_for_one_layer)
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
if classifier_layer is None and output_dim == -1:
raise Exception(
"Either output dimention or classifier layers must be specified"
)
elif classifier_layer is None:
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
else:
self.additional_linear_layer = classifier_layer
if output_dim != -1:
print(
"Both classifier layer and output dimension are specified. Output dimension parameter is ignored."
)
self.bert_model = bert_model

def forward(self, token_ids, segment_ids, attention_mask):
Expand Down Expand Up @@ -171,7 +178,7 @@ def forward(self, token_ids, segment_ids, attention_mask):
# Sort of hack to make it work with distributed: this way the pooler layer
# is used for grad computation, even though it does not change anything...
# in practice, it just adds a very (768*768) x (768*batchsize) matmul
result += 0 * torch.sum(output_pooler)
result = result + 0 * torch.sum(output_pooler)
return result


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dec_emb_loss: 0.0151884
dec_hid_loss: 0.662957
dec_hid_loss: 0.662956
dec_self_attn_loss: 497.628
enc_dec_attn_loss: 230.709
enc_emb_loss: 0.0109334
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ enc_emb_loss: 0.00210945
enc_hid_loss: 0.279337
enc_loss: 0.284342
enc_self_attn_loss: 371.567
loss: 11.7943
loss: 11.7944
pred_loss: 6.80161
30 changes: 29 additions & 1 deletion tests/nightly/gpu/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,41 @@ def test_crossencoder(self):
batchsize=2,
learningrate=1e-3,
gradient_clip=1.0,
type_optimization="all_encoder_layers",
text_truncate=8,
label_truncate=8,
)
)
self.assertGreaterEqual(test['accuracy'], 0.8)

def test_bertclassifier(self):
valid, test = testing_utils.train_model(
dict(
task='integration_tests:classifier',
model='bert_classifier/bert_classifier',
num_epochs=2,
batchsize=2,
learningrate=1e-2,
gradient_clip=1.0,
classes=["zero", "one"],
)
)
self.assertGreaterEqual(test['accuracy'], 0.9)

def test_bertclassifier_with_relu(self):
valid, test = testing_utils.train_model(
dict(
task='integration_tests:classifier',
model='bert_classifier/bert_classifier',
num_epochs=2,
batchsize=2,
learningrate=1e-2,
gradient_clip=1.0,
classes=["zero", "one"],
classifier_layers=["linear,64", "linear,2", "relu"],
)
)
self.assertGreaterEqual(test['accuracy'], 0.9)


if __name__ == '__main__':
unittest.main()