Quick Start Guide
Binary classification based on text data from RuTweetCorp (https://study.mokoron.com/)
- negative: 0
- positive: 1
When I learn a new approach, the available articles and tutorials are usually extremely detailed, covering everything from primary data processing to building learning curves. I have always wanted to grasp the essence of an approach quickly and start using existing developments right away, rather than painfully working through a wall of someone else's code. So I decided to make a solution that is as simple and transparent as possible, not overloaded with unnecessary code, and easy to figure out quickly.
I won't write anything about BERT itself: there are plenty of great articles about it, so we will simply use it as a black box.
Cleaned Russian-language Twitter data from RuTweetCorp (https://study.mokoron.com/) is used, keeping only texts longer than 100 characters.
The CustomDataset class is needed to feed the data to a transformers model. It inherits from the PyTorch Dataset class and defines the three required methods: __init__, __len__, and __getitem__. Its main purpose is to return tokenized data in the required format.
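A minimal sketch of such a class, assuming a huggingface tokenizer and a maximum length of 512 tokens (the field names and max_len value are illustrative, not the exact original implementation):

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len=512):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize a single text, padding/truncating it to a fixed length
        encoding = self.tokenizer(
            str(self.texts[idx]),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(self.targets[idx], dtype=torch.long),
        }
```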
When the classifier is initialized, the following actions are performed (a sketch follows the list):
- The model and tokenizer are downloaded from the huggingface repository;
- The target device for computation (GPU or CPU) is detected;
- The dimension of embeddings is determined;
- The number of classes is set;
- The number of epochs for training is set.
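A sketch of this initialization, assuming a BertClassifier class and the cointegrated/rubert-tiny checkpoint as the default model (the class layout and default values are assumptions; any BERT checkpoint from the huggingface hub should work the same way):

```python
import torch
from transformers import BertForSequenceClassification, BertTokenizer

class BertClassifier:
    def __init__(self, model_path='cointegrated/rubert-tiny', n_classes=2, epochs=4):
        # Download the model and tokenizer from the huggingface repository
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        # Determine the target device for computation
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Determine the dimension of the embeddings (hidden states)
        self.hidden_size = self.model.config.hidden_size
        # Set the number of classes by replacing the classification head
        self.model.classifier = torch.nn.Linear(self.hidden_size, n_classes)
        self.model.to(self.device)
        # Set the number of epochs for training
        self.epochs = epochs
```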
To train BERT, you need to initialize several auxiliary elements (see the sketch after this list):
- DataLoader: needed to create batches;
- Optimizer: gradient descent optimizer;
- Scheduler: adjusts the optimizer parameters (the learning rate) over the course of training;
- Loss: the loss function from which the model error is computed.
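A sketch of this setup as a method that continues the BertClassifier class above (the method name preparation, the batch size, and the learning rate are assumptions):

```python
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# Continues the BertClassifier class above
def preparation(self, train_set, valid_set, batch_size=16, lr=2e-5):
    # DataLoaders create batches from the CustomDataset instances
    self.train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    self.valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
    # Gradient descent optimizer
    self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
    # Scheduler that linearly decays the learning rate over all training steps
    self.scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=0,
        num_training_steps=len(self.train_loader) * self.epochs,
    )
    # Loss function for computing the model error
    self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)
```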
Training for one epoch is implemented in the fit method (a sketch follows the list):
- The data is generated in batches in a loop by the DataLoader;
- Each batch is fed into the model;
- At the output, we get the probability distribution over classes and the loss value;
- A step is then taken with all the auxiliary functions:
  - loss.backward: backward propagation of the error;
  - clip_grad_norm: gradient clipping to prevent exploding gradients;
  - optimizer.step: optimizer step;
  - scheduler.step: scheduler step;
  - optimizer.zero_grad: zeroing out the gradients.
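A sketch of fit under the same assumptions as above (the max_norm value for gradient clipping is illustrative):

```python
def fit(self):
    self.model.train()
    total_loss = 0
    for batch in self.train_loader:
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        targets = batch['targets'].to(self.device)

        # Forward pass: logits over the classes, then the loss value
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        loss = self.loss_fn(outputs.logits, targets)
        total_loss += loss.item()

        # A step on all the auxiliary functions
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
    return total_loss / len(self.train_loader)
```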
Evaluation on the validation set is carried out in the eval method. It runs under the torch.no_grad context manager, which disables gradient computation so the model does not learn from the validation data.
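A sketch of eval, reporting accuracy and average loss (the exact metrics reported are an assumption):

```python
def eval(self):
    self.model.eval()
    correct, total, total_loss = 0, 0, 0
    # no_grad disables gradient computation on the validation data
    with torch.no_grad():
        for batch in self.valid_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            targets = batch['targets'].to(self.device)

            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            total_loss += self.loss_fn(outputs.logits, targets).item()

            preds = outputs.logits.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return correct / total, total_loss / len(self.valid_loader)
```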
For multi-epoch training, the train method is used, which calls the fit and eval methods sequentially for each epoch.
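A sketch of train tying the two together (the logging format is an assumption):

```python
def train(self):
    for epoch in range(self.epochs):
        train_loss = self.fit()
        val_acc, val_loss = self.eval()
        print(f'epoch {epoch + 1}/{self.epochs}: '
              f'train loss {train_loss:.4f}, '
              f'val loss {val_loss:.4f}, val accuracy {val_acc:.4f}')
```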
To predict the class of a new text, the predict method is used; it makes sense to call it only after the model has been trained.
The method works like this (a sketch follows the list):
- Input text is tokenized;
- Tokenized text is fed into the model;
- At the output, we get the probabilities of the classes;
- The label of the most likely class is returned.
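A sketch of predict under the same assumptions as above:

```python
def predict(self, text):
    self.model.eval()
    # Tokenize the input text the same way as in CustomDataset
    encoding = self.tokenizer(
        text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = self.model(
            input_ids=encoding['input_ids'].to(self.device),
            attention_mask=encoding['attention_mask'].to(self.device),
        )
    # The index of the highest logit is the predicted class label
    return outputs.logits.argmax(dim=1).item()
```

For this task, the returned label is 0 for negative and 1 for positive.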
I wanted it to be as simple as possible, but it still turned out somewhat voluminous. Please understand and forgive.