
Add Semantic Similarity with BERT example #205

Merged: 11 commits, Aug 31, 2020

Conversation

MohamadMerchant (Contributor):

No description provided.

fchollet (Contributor) left a comment:

Thank you for the PR! This will make a great new example 👍

    with strategy.scope():
        model = build_model()
    print(f"Strategy: {strategy}")
except:
fchollet (Contributor):

The try-except should not be necessary. Why would it fail?

MohamadMerchant (Contributor, Author):

I thought it would fail while training on CPU, but it doesn't, so I have removed the try-except block; it isn't necessary.

Description: Natural Language Inference by Fine tuning BERT model on SNLI Corpus.
"""
"""
## **Introduction**
fchollet (Contributor):

No need for the bold markers, it's already a title

MohamadMerchant (Contributor, Author):

Removed bold markers.

print(f"Total test samples: {valid_df.shape[0]}")

"""
Dataset Info:
fchollet (Contributor):

Please format this block as markdown

MohamadMerchant (Contributor, Author):

Formatted.

## Preprocessing
"""

# we have some nan in our train data, we will simply drop them
fchollet (Contributor):

Please make sure that all comments start with a capital letter and end with a period

MohamadMerchant (Contributor, Author):

Done with all the comments.


def build_model():
    """
    model inputs:
fchollet (Contributor):

Docstring format: start with a one-line description. Then "Arguments:" section.

MohamadMerchant (Contributor, Author):

Changed it to comment descriptions.

"""
## Train the Model
"""
h = model.fit_generator(
fchollet (Contributor):

Use fully spelled-out argument names.

MohamadMerchant (Contributor, Author):

Changed all argument and variable names to fully spelled-out forms.

"""


class DataGenerator(tf.keras.utils.Sequence):
fchollet (Contributor):

Please use a more descriptive name than "DataGenerator"

MohamadMerchant (Contributor, Author):

Used BertSemanticDataGenerator.


"""

"""
fchollet (Contributor):

"Check results on some example sentence pairs"

MohamadMerchant (Contributor, Author):

Added line.

    shape=(MAXLEN,), dtype=tf.int32, name="tt_ids"
)
# Loading pretrained BERT model
bertModel = TFBertModel.from_pretrained(MODEL)
fchollet (Contributor):

Use lowercase snake_case for variable names.

MohamadMerchant (Contributor, Author):

Used lower_case for all variable names.

LR = 3e-5
np.random.seed(42)
# we will use base-base-uncased pretrained model
MODEL = "bert-base-uncased"
fchollet (Contributor):

"MODEL" is too generic, this is "pretrained_model_name"

MohamadMerchant (Contributor, Author):

Removed "MODEL" from the configurations and added a descriptive comment where the model is loaded.

MohamadMerchant (Contributor, Author):

Thanks for the review @fchollet. I have made all the necessary changes.

fchollet (Contributor) left a comment:

Thanks for the update!

"""

"""
## Setup
fchollet (Contributor):

Please mention that transformers should be installed.

MohamadMerchant (Contributor, Author):

Mentioned to install transformers in Setup description.

import pandas as pd
import tensorflow as tf
import transformers
from transformers import BertTokenizer, TFBertModel
fchollet (Contributor):

You don't need this import line, you can just access these objects from transformers.

MohamadMerchant (Contributor, Author):

Removed this line.

batch_size = 32
epochs = 4
learning_rate = 3e-5
np.random.seed(42)
fchollet (Contributor):

It is unclear what this seed is meant to control. It will have no effect over Keras reproducibility. If it is specifically for the data shuffling, prefer passing a seed argument to your Sequence object and use it locally there (via a RandomState object). In general avoid global seeding since it has side effects.

MohamadMerchant (Contributor, Author):

Yes, it is meant specifically for data shuffling.
If that's not a problem, I have changed it to use a static seed value in a RandomState object inside the Sequence.
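The pattern suggested here can be sketched with a minimal stand-alone generator (names are illustrative; Python's random.Random stands in for np.random.RandomState to keep the sketch dependency-free):

```python
import random


class ShufflingSequence:
    """Sequence-style generator that owns its RNG instead of using a global seed."""

    def __init__(self, samples, batch_size, shuffle=True, seed=42):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Local RNG: reproducible shuffling with no global side effects.
        self.rng = random.Random(seed)
        self.indexes = list(range(len(self.samples)))
        if self.shuffle:
            self.rng.shuffle(self.indexes)

    def __len__(self):
        # Number of full batches per epoch.
        return len(self.samples) // self.batch_size

    def __getitem__(self, idx):
        batch = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
        return [self.samples[i] for i in batch]

    def on_epoch_end(self):
        if self.shuffle:
            self.rng.shuffle(self.indexes)
```

Two generators built with the same seed produce the same shuffled order, while the global random state is left untouched.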

"""
Distribution of our training and validation targets.
"""
print(train_df.similarity.value_counts())
fchollet (Contributor):

Use string formatting to print a sentence that makes it clear this is training (resp. validation) data

MohamadMerchant (Contributor, Author):

Added string to separate both training and validation data.

"""

# We have some nan in our train data, we will simply drop them.
print(train_df.isnull().sum())
fchollet (Contributor):

Use string formatting to print a sentence

MohamadMerchant (Contributor, Author):

Added string.

"""
history = model.fit_generator(
    train_data,
    steps_per_epoch=len(train_data) // batch_size,
fchollet (Contributor):

Since a Sequence has a __len__, this argument should not be necessary (same for validation_steps)

MohamadMerchant (Contributor, Author) commented Aug 21, 2020:

I have experimented with steps_per_epoch, but removing it increases training time drastically.

Experiment 1: with steps_per_epoch, the whole training set took around 1 hour to train for 4 epochs.
Train samples: 549661
Train steps: 536 ((549661 // 32) // 32)

Experiment 2: without steps_per_epoch, 100k training samples took around 1 hour to finish only 1 epoch.
Train samples: 100000
Train steps: 1560 (100000 // 64)

So I have used it for faster training, since training without steps_per_epoch took a lot of time, and increasing batch_size beyond 64 causes OOM errors.

fchollet (Contributor):

len(train_data) is already num_sentences // batch_size.

Here you are "training faster" simply because you are training on 32x fewer batches per epoch (one batch out of 32).

The proper procedure would be to train on all batches at each epoch (no steps_per_epoch arg), and train for fewer epochs.

Since you are only training with 4 epochs, it sounds like your model is too large for your dataset. Consider adding more dropout. Is bert_model supposed to be trainable?
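The batch arithmetic in this exchange is easy to verify directly (numbers taken from the experiments quoted above):

```python
train_samples = 549661
batch_size = 32

# For a Sequence, len(train_data) is already the number of batches per epoch.
batches_per_epoch = train_samples // batch_size

# Passing steps_per_epoch=len(train_data) // batch_size divides by batch_size
# a second time, so each "epoch" covers only ~1/32 of the available batches.
steps_per_epoch = batches_per_epoch // batch_size

print(batches_per_epoch)  # 17176
print(steps_per_epoch)    # 536
```

This reproduces the 536 steps reported in the experiment, and shows why that run finished 32x faster per epoch.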

MohamadMerchant (Contributor, Author) commented Aug 23, 2020:

Thanks for the explanation, I get the point now.

I have made the changes and trained a model on 100k train samples (with comment descriptions), with no steps_per_epoch argument, for two epochs with higher dropout (0.3), and it outperformed all the earlier models I tried.

I also trained with bert_model set to non-trainable, and with some stacked dense layers, but in both cases the model didn't improve and was stuck at ~35% accuracy. I think Hugging Face has larger models pretrained on MNLI, such as roberta-large-mnli, but those are too large for our dataset.

"""
## Inference on custom sentences
"""
labels = ["CONTRADICTION", "ENTAILMENT", "NEUTRAL"]
fchollet (Contributor):

This is a kind of global constant, so you can move it higher up in the file

MohamadMerchant (Contributor, Author):

Moved labels to the "Configurations" section with a descriptive comment.

MohamadMerchant (Contributor, Author):

I have committed all the changes except the steps_per_epoch one, as explained in the comment above. Please let me know what we can do about it.

fchollet (Contributor) left a comment:

Thanks for the update. It looks good! 👍

The main remaining issue is the training process. Please try feature extraction first, then see how much gains come from fine-tuning.

Note that I have pushed some copyedits. Please pull them first before editing.


def on_epoch_end(self):
    # Shuffle indexes after each epoch if shuffle is set to True.
    self.indexes = np.arange(len(self.sentence1))
fchollet (Contributor):

You can move this line inside the if

MohamadMerchant (Contributor, Author) commented Aug 26, 2020:

Moving this line inside the if results in an error when shuffle=False, because the validation and test generator objects then cannot find self.indexes at this line in __getitem__:
indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]

fchollet (Contributor) commented Aug 27, 2020:

Since it is used in __getitem__, you should define self.indexes in __init__.
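In sketch form, the arrangement suggested here (a simplified stand-in for the PR's BertSemanticDataGenerator; names abbreviated) looks like:

```python
import numpy as np


class PairGenerator:
    """Simplified generator: self.indexes is created in __init__, not in on_epoch_end."""

    def __init__(self, sentence1, batch_size, shuffle=True):
        self.sentence1 = sentence1
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Local RNG for reproducible shuffling without global seeding.
        self.rng = np.random.RandomState(42)
        # Always defined, so __getitem__ works even for validation/test
        # generators created with shuffle=False.
        self.indexes = np.arange(len(self.sentence1))

    def __len__(self):
        return len(self.sentence1) // self.batch_size

    def __getitem__(self, idx):
        # Retrieves the batch of index idx.
        indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
        return [self.sentence1[i] for i in indexes]

    def on_epoch_end(self):
        # Shuffle indexes after each epoch if shuffle is set to True;
        # the attribute itself always exists regardless.
        if self.shuffle:
            self.rng.shuffle(self.indexes)
```

With this arrangement, a shuffle=False generator can be indexed immediately without any on_epoch_end call having run.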

        return len(self.sentence1) // self.batch_size

    def __getitem__(self, idx):
        # Generates batch of data.
fchollet (Contributor):

You can be more precise: "Retrieves the batch of index idx."

MohamadMerchant (Contributor, Author):

Added.

"""
## Train the Model
"""
history = model.fit_generator(train_data, validation_data=valid_data, epochs=epochs,)
fchollet (Contributor):

A big problem here is that you're not doing fine-tuning properly.

To do fine-tuning, you need 2 steps:

  1. Freeze bert_model and fit the model. This is called "feature extraction": you just reuse the pretrained features without modifying them.
  2. Unfreeze bert_model and fit again (with a very small learning rate). This is called fine-tuning.

Step 2 may not be necessary -- for many applications, feature extraction is sufficient.

Also, please note that:

  • You can pass a generator to fit. fit_generator is legacy.
  • If you fit from a generator, you should generally use data multiprocessing to get better GPU utilization. See the docs for fit().
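The two-step recipe described above looks roughly like this in code (a sketch against the example's objects — bert_model, model, train_data, valid_data, epochs — with an assumed fine-tuning learning rate of 1e-5):

```python
import tensorflow as tf

# Step 1: feature extraction -- freeze the pretrained encoder and train
# only the new layers stacked on top of it.
bert_model.trainable = False
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_data, validation_data=valid_data, epochs=epochs)

# Step 2 (optional): fine-tuning -- unfreeze the encoder, recompile (needed
# for the trainable change to take effect), and train with a very small
# learning rate so the pretrained weights are only gently updated.
bert_model.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_data, validation_data=valid_data, epochs=epochs)
```

Recompiling between the two fits is essential: in Keras, changing a layer's trainable attribute only takes effect on the next compile.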

MohamadMerchant (Contributor, Author):

I used only "feature extraction" first, training with bert_model frozen, which gave accuracy around 37%.
After some tuning, I experimented with stacking a classification head on top of the frozen bert_model:

  1. Stacking 2 dense layers with dropout gave accuracy around 45%.
  2. Stacking 1 Bi-LSTM layer with dropout gave accuracy around 69%.

Accuracy with bert_model unfrozen (our current model) is around 87%.

I went through this paper yesterday; they experimented with different datasets and models, and I found this result related to our task:

Frozen up to  MNLI     MNLI-mm
0th           84.2     82.0
9th           82.0     82.4
12th          56.4     57.1

I have changed fit_generator to fit and also used multiprocessing for better GPU utilization.

fchollet (Contributor):

It makes sense that you get better results by training the entire BERT model, but in this case it makes little sense to load a pretrained model (since you lose its representations almost immediately due to the presence of untrained layers in the model).

Either load the BERT model from scratch (no pretrained weights), or do actual fine-tuning (the two steps I described above). The first option is the simplest.

MohamadMerchant (Contributor, Author) commented Aug 30, 2020:

I got your point the first time, but I was a little concerned about the accuracy. Now the model is trained with both steps, "feature extraction" and "fine-tuning" (unfreezing bert_model), which gave quite good results. Thank you for your great advice.
I have also made a change in the data generator and used batch_encode_plus, which has sped up training.
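The data-generator change mentioned here looks roughly like the following (a sketch; argument names for padding/truncation varied across transformers releases around 2020, so treat the exact parameters as assumptions, and batch_sentence1/batch_sentence2/max_length as illustrative names):

```python
# Encode a whole batch of sentence pairs in one tokenizer call, instead of
# calling encode_plus once per pair. The tokenizer inserts [CLS]/[SEP] and
# produces input_ids, attention_mask and token_type_ids for each pair.
encoded = tokenizer.batch_encode_plus(
    list(zip(batch_sentence1, batch_sentence2)),  # list of (premise, hypothesis)
    add_special_tokens=True,
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_attention_mask=True,
    return_token_type_ids=True,
    return_tensors="np",
)
input_ids = encoded["input_ids"]
attention_masks = encoded["attention_mask"]
token_type_ids = encoded["token_type_ids"]
```

Batching the tokenization amortizes per-call overhead, which is where the reported speedup comes from.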

    batch_size=batch_size,
    shuffle=False,
)
model.evaluate_generator(test_data, verbose=1)
fchollet (Contributor):

Likewise, just evaluate

    validation_data=valid_data,
    epochs=epochs,
    use_multiprocessing=True,
    workers=-1,
fchollet (Contributor):

Use 4

MohamadMerchant (Contributor, Author):

Using 4 or 8 causes nondeterministic deadlocks and training doesn't start.
(-1 uses 8 workers on Colab.)

fchollet (Contributor) left a comment:

Thank you for the update, it looks good to me!

I've made some copyedits, please pull them first.

Everything looks great, so you can add the generated files now. Thanks for the great contribution!

MohamadMerchant (Contributor, Author):

Thank you for the great review. It was great contributing to Keras.

@fchollet fchollet merged commit e59be23 into keras-team:master Aug 31, 2020