[RU|EN]

Open In Colab

BERT for a classification problem

Quick Start Guide
Binary classification based on text data from RuTweetCorp (https://study.mokoron.com/):

  • negative: 0
  • positive: 1

What for?

When I learn a new approach, the available articles and tutorials are usually extremely detailed, covering everything from primary data processing to building learning curves. I have always wanted to grasp the essence of an approach quickly and start using existing work right away, rather than painfully digging through a wall of someone else's code. So I decided to make a solution that is as simple and transparent as possible, not overloaded with unnecessary code, and easy to figure out quickly.
I won't write anything about BERT itself: there are plenty of great articles about it, so we'll just use it as a black box.

Structure

Training data

Cleaned tweets from the Russian-language Twitter corpus RuTweetCorp (https://study.mokoron.com/), each longer than 100 characters, are used.
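
A minimal sketch of loading such data with pandas. The file names and column layout here are assumptions, not taken from the repository; adjust them to the actual corpus download:

```python
import pandas as pd

# Hypothetical file names; RuTweetCorp ships positive and negative tweets separately
positive = pd.read_csv('positive.csv', sep=';', header=None)
negative = pd.read_csv('negative.csv', sep=';', header=None)

# The column index of the tweet text (3 here) is an assumption about the CSV layout
df = pd.DataFrame({
    'text': pd.concat([positive[3], negative[3]], ignore_index=True),
    'label': [1] * len(positive) + [0] * len(negative),
})

# Keep only tweets longer than 100 characters, as described above
df = df[df['text'].str.len() > 100].reset_index(drop=True)
```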

CustomDataset

The CustomDataset class is required for use with the transformers library. It inherits from the PyTorch Dataset class and defines the three required methods: __init__, __len__, and __getitem__. Its main purpose is to return tokenized data in the format the model expects.
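
A minimal sketch of such a class, assuming the texts and labels are plain Python lists and the tokenizer comes from the transformers library (field names and the max_len value are illustrative):

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len=512):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize one text into the tensors the model expects
        encoding = self.tokenizer.encode_plus(
            str(self.texts[idx]),
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(self.targets[idx], dtype=torch.long),
        }
```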

Initialize

When the classifier is initialized, the following actions are performed (a sketch follows the list):

  • The model and tokenizer are downloaded from the Hugging Face hub;
  • The target device for computation (GPU if available, otherwise CPU) is selected;
  • The dimension of the embeddings is determined;
  • The number of classes is set;
  • The number of training epochs is set.
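
A sketch of what such an initializer might look like, assuming BertForSequenceClassification from transformers; the class and argument names are illustrative, not the repository's actual API:

```python
import torch
from transformers import BertForSequenceClassification, BertTokenizer

class BertClassifier:
    def __init__(self, model_path, tokenizer_path, n_classes=2, epochs=1):
        # Download the model and tokenizer from the Hugging Face hub
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        # Pick the target device for computation
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Dimension of the embeddings produced by the encoder
        hidden_size = self.model.config.hidden_size
        # Replace the classification head to match the number of classes
        self.model.classifier = torch.nn.Linear(hidden_size, n_classes)
        self.epochs = epochs
        self.model.to(self.device)
```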

Preparation

To train BERT, several auxiliary objects need to be initialized (a sketch follows the list):

  • DataLoader: creates batches from the dataset;
  • Optimizer: the gradient descent optimizer;
  • Scheduler: adjusts the optimizer parameters (e.g. the learning rate) during training;
  • Loss: the loss function from which the model error is computed.
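
A sketch of this preparation step, reusing CustomDataset instances from above (batch size, learning rate, and warmup settings are illustrative):

```python
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# train_set and valid_set are CustomDataset instances
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=16, shuffle=False)

optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_loader) * epochs,
)
loss_fn = CrossEntropyLoss()
```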

Train

  • Training for one epoch is described in the fit method:
    • Batches are generated in a loop by the DataLoader;
    • Each batch is fed into the model;
    • The output is a probability distribution over the classes, from which the loss value is computed;
    • Then a step is taken through all the auxiliary routines:
      • loss.backward: backward propagation of the error;
      • clip_grad_norm: clips gradients to prevent gradient explosion;
      • optimizer.step: optimizer step;
      • scheduler.step: scheduler step;
      • optimizer.zero_grad: resets the gradients.
  • Evaluation on the validation set is carried out by the eval method, inside a torch.no_grad context so that no gradients are computed and the model is not updated on the validation data.
  • For multi-epoch training, the train method is used, which calls the fit and eval methods in sequence (a sketch of this loop follows the list).
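
A sketch of this training loop as free functions, using the objects from the preparation step (the repository implements these as methods; the names here are illustrative):

```python
import torch
from torch.nn.utils import clip_grad_norm_

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def fit(model, loader):
    """Train for one epoch."""
    model.train()
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['targets'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs.logits, targets)

        loss.backward()                                    # backpropagate the error
        clip_grad_norm_(model.parameters(), max_norm=1.0)  # prevent gradient explosion
        optimizer.step()                                   # optimizer step
        scheduler.step()                                   # scheduler step
        optimizer.zero_grad()                              # reset gradients

def evaluate(model, loader):
    """Check accuracy on the validation set without computing gradients."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['targets'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return correct / total

def train(model, epochs):
    """Run fit and evaluate sequentially for several epochs."""
    for epoch in range(epochs):
        fit(model, train_loader)
        accuracy = evaluate(model, valid_loader)
        print(f'epoch {epoch + 1}: validation accuracy {accuracy:.4f}')
```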

Inference

To predict the class of a new text, the predict method is used; it only makes sense to call it after the model has been trained.
The method works as follows (a sketch follows the list):

  • The input text is tokenized;
  • The tokenized text is fed into the model;
  • The model outputs class probabilities;
  • The label of the most likely class is returned.
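
A sketch of such a predict function, matching the steps above (the max_length value and function signature are illustrative):

```python
import torch

def predict(model, tokenizer, text, device):
    model.eval()
    # Tokenize the input text
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device),
        )
    # Class probabilities, then the label of the most likely class
    probabilities = torch.softmax(outputs.logits, dim=1)
    return torch.argmax(probabilities, dim=1).item()
```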

Conclusion

I wanted to keep it as simple as possible, but it still turned out somewhat voluminous. Please understand and forgive.