
English | 中文说明



Text Classification By Torch


Introduction

This repository is built for text classification tasks based on Torch, incorporating various training tricks, acceleration methods, distillation, compression, and pruning techniques. It supports both single-label and multi-label training. More detailed and customizable training configurations can be explored and adjusted in the config.py file.

While text classification problems can be effectively addressed by Large Language Models (LLMs) in the LLM era, traditional classification models are still needed in specific scenarios, such as high-concurrency services or low-resource devices. This project implements most of the classification training techniques used before the LLM era. Additionally, it organizes relevant models and training-related techniques along with their corresponding research papers at the end of the documentation, providing users with a convenient reference.

Updates

| Date | Version | Description |
|------|---------|-------------|
| 2023-12-01 | v5.4.0 | Added support for multi-label text classification and introduced a multi-label classification loss function. |
| 2023-05-31 | v5.3.0 | Modified the data post-processing logic to occur after each training batch. |
| 2022-12-12 | v5.2.0 | Integrated TextPruner to add model pruning methods, including pruning of vocabularies and intermediate Transformer layers. |
| 2022-11-10 | v5.1.1 | Added key adversarial training methods: FreeLB and AWP; refactored adversarial utility classes and introduced the NoisyTune method. |
| 2022-10-21 | v5.0.0 | Integrated the TextBrewer distillation toolkit into the framework, added self-distillation support, and enabled distillation for smaller pre-trained models. |
| 2022-09-23 | v4.0.2 | Added EMA (Exponential Moving Average) for model smoothing. |
| 2022-09-01 | v4.0.1 | Enabled distillation of K-fold models into a single model. |
| 2022-08-31 | v4.0.0 | Introduced K-folds to better utilize validation datasets; merged distillation training into the train logic and unified loss function management. |
| 2022-06-14 | v3.9.2 | Added support for the DeBERTaV3 model. |
| 2022-05-23 | v3.9.0 | Added label smoothing to mitigate inconsistencies caused by different annotators. |
| 2022-04-24 | v3.6.0 | Introduced multi-sample dropout. |
| 2022-04-21 | v3.5.0 | Added SWA (Stochastic Weight Averaging) for training. |
| 2022-03-31 | v3.4.0 | Added FastText, TextRCNN, and MiniLM models, configurable optimizers, and probability outputs. |
| 2022-03-15 | v3.3.0 | Integrated Transformer, XLNet, ALBert, RoBerta, and Electra models. |
| 2022-03-03 | v3.1.0 | Introduced warmup for training, options for initializing non-finetuned model parameters, and support for fp16 mixed-precision training. |
| 2021-12-20 | v3.0.0 | Added various training tricks: R-Drop, FGM, PGD, and Focal Loss. |
| 2021-12-20 | v2.2.0 | Provided methods to convert Torch models to ONNX format. |
| 2021-09-23 | v2.1.0 | Automatically generates validation datasets if not provided; added test dataset evaluation; enabled configurable pre-trained models. |
| 2021-08-27 | v2.0.0 | Provided two different distillation methods with detailed references to research papers. |
| 2021-08-10 | v1.0.0 | Initial release of the repository. |

Requirement

Key environment requirements:

  • Python: 3.10+
  • torch: 2.4.1+
  • Additional dependencies: see requirements.txt

Usage

The project supports multiple modes as follows:

| Mode | Detail |
|------|--------|
| train_classifier | Train a classification model |
| interactive_predict | Interactive prediction mode |
| test | Evaluate on a test dataset |
| convert_onnx | Save the Torch model as an ONNX file |
| show_model_info | Print model parameters |
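
The mode is selected in config.py before running main.py. A minimal sketch, showing only the mode variable (the real config.py contains many other options):

```python
# Select the running mode in config.py before launching main.py.
# Valid values follow the table above.
mode = 'train_classifier'   # or 'interactive_predict', 'test', 'convert_onnx', 'show_model_info'
```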

It also supports the following stages for fine-tuning, training small models, distillation, and pruning:

| Stage | Detail |
|-------|--------|
| finetune | Fine-tune a pre-trained model |
| train_small_model | Train a small model independently |
| distillation | Model distillation |
| prune | Model pruning |

Supported Models:

| Type | Detail |
|------|--------|
| Pretrained Models | Bert, DistilBert, RoBerta, ALBert, XLNet, Electra, MiniLM, DeBertaV3, XLM-RoBERTa |
| Traditional Models | FastText, TextCNN, TextRNN, TextRCNN, Transformer |

Models are configured in config.py: f_model_type specifies the pre-trained model and s_model_type specifies the small model. Use the stage parameter to choose which operation is performed on which model.
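
A rough sketch of how these switches fit together in config.py (only the three options named above are shown; the actual file contains many more settings):

```python
# Minimal illustration of the model/stage switches described above.
f_model_type = 'Bert'      # pre-trained model: Bert, DistilBert, RoBerta, ALBert, XLNet, ...
s_model_type = 'TextCNN'   # small/traditional model: FastText, TextCNN, TextRNN, TextRCNN, ...
stage = 'finetune'         # or 'train_small_model', 'distillation', 'prune'
```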

Train

Three sample datasets are provided for a quick start:

| Dataset | Task |
|---------|------|
| example1 | Binary classification (sentiment analysis) |
| example2 | Multi-class classification (news categorization) |
| example3 | Multi-label classification |

Replace the project's config.py with the config.py from the corresponding dataset directory and run main.py to start training. Intermediate validation results are displayed during training. Incremental training is supported.

Distill

The project provides various distillation methods to train student models while preserving classification performance. Configure distillation parameters in distill_configure within config.py.

  • Cross-Model Distillation

Use example_datasets2/config.py for a simple configuration that distills a trained Bert model into a TextCNN model. Modify the stage as follows and run main.py to start distillation:

```python
stage = 'distillation'
```

The distilled model also supports testing on the test dataset (set mode = 'test') and making predictions (set mode = 'interactive_predict').
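
For orientation, a cross-model distill_configure might look roughly like the following. The key names mirror the self-distillation example further below; the values here are purely illustrative, and the actual defaults ship in example_datasets2/config.py:

```python
# Illustrative cross-model distillation settings (Bert teacher -> TextCNN student).
# Key names follow the self-distillation example below; values are placeholders.
distill_configure = {
    'self_distillation': False,
    'teacher_model_type': 'Bert',
    'student_model_type': 'TextCNN',
    'checkpoints_dir': 'checkpoints/example2_distillation',
    'epoch': 100,
    'batch_size': 32,
    'learning_rate': 0.0001,
    'alpha': 0.1,          # typically the balance between hard-label and distillation losses
    'temperature': 4,      # typically softens the teacher logits before the distillation loss
    'student_model_name': 'distillation_model.bin',
    'teacher_model_name': 'torch.bin',
}
```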

  • Self-Distillation


Self-distillation of pre-trained models can reduce the number of Transformer blocks in the model. For example, a 12-layer BERT-Base model can be distilled into a model with only 3 Transformer blocks. This distillation logic is implemented by integrating TextBrewer. Continue using the config.py file under the example_datasets2 directory. Modify the stage to distillation and update the distill_configure with the following configuration. Then, run main.py to start the distillation process.

```python
distill_configure = {
    'self_distillation': True,
    'distillation_method': 'mse',
    'teacher_model_type': 'Bert',
    'student_model_type': 'Bert',
    'checkpoints_dir': 'checkpoints/example2_distillation_1',
    'epoch': 100,
    'batch_size': 32,
    'learning_rate': 0.0001,
    'print_per_batch': 50,
    'is_early_stop': True,
    'patient': 2,
    'alpha': 0.1,
    'temperature': 4,
    'student_model_name': 'distillation_model.bin',
    'teacher_model_name': 'torch.bin',
    'distill_mlm_config': {
        'attention_probs_dropout_prob': 0.1,
        'hidden_act': 'gelu',
        'hidden_dropout_prob': 0.1,
        'hidden_size': 768,
        'initializer_range': 0.02,
        'intermediate_size': 3072,
        'max_position_embeddings': 512,
        'num_attention_heads': 12,
        'num_hidden_layers': 3,
        'type_vocab_size': 2,
        'vocab_size': 21128
    },
    'intermediate_matches': [
        {"layer_T": 0, "layer_S": 0, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
        {"layer_T": 4, "layer_S": 1, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
        {"layer_T": 8, "layer_S": 2, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
        {"layer_T": 12, "layer_S": 3, "feature": "hidden", "loss": "hidden_mse", "weight": 1}
    ]
}
```

The distill_mlm_config and intermediate_matches can be configured based on the examples provided by TextBrewer. They support distillation of the hidden states and multi-head attention within Transformer blocks.
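
For example, attention maps can be matched in the same way as the hidden states above. An illustrative entry is shown below; consult the TextBrewer documentation for the full list of supported feature and loss names:

```python
# Illustrative intermediate match on attention, alongside the hidden-state matches above.
{"layer_T": 8, "layer_S": 2, "feature": "attention", "loss": "attention_mse", "weight": 1}
```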

  • Multi-Model Fusion Distillation
    If you have trained K models using K-Folds and wish to merge them into one, this project provides a method for multi-model fusion distillation: cross-model distillation is performed sequentially from each of the K models to a single target model, as sketched below.
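
The following schematic illustrates that sequential scheme. It is only a sketch built on plain PyTorch with a standard temperature-scaled soft-label KD loss, not the project's actual training loop:

```python
import torch
import torch.nn.functional as F

def sequential_fusion_distill(teachers, student, loader, optimizer, temperature=4.0):
    """Distill K trained fold models into a single student, one teacher at a time.
    Purely illustrative; the repository's own distillation code differs."""
    student.train()
    for teacher in teachers:                          # the K fold models, used sequentially
        teacher.eval()
        for inputs, _ in loader:
            with torch.no_grad():
                t_logits = teacher(inputs)            # soft targets from the current fold
            s_logits = student(inputs)
            # Classic temperature-scaled KD loss between student and teacher distributions.
            loss = F.kl_div(F.log_softmax(s_logits / temperature, dim=-1),
                            F.softmax(t_logits / temperature, dim=-1),
                            reduction='batchmean') * temperature ** 2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return student
```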

Interactive Predict

After training the model, set mode = 'interactive_predict' to quickly test the model's performance:

  • When K-Folds is enabled, K models will be retained locally. During prediction, the logits from the K models for the same sample will be averaged (model fusion), and the corresponding classification will then be calculated.
  • For multi-label classification, the system will print every label whose sigmoid(logits) score exceeds 0.5 (see the sketch below).
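
A minimal sketch of those two post-processing rules (illustrative only; the project's own prediction code may differ):

```python
import torch

def decode_prediction(fold_logits, id2label, multi_label=False, threshold=0.5):
    """fold_logits: list of [num_classes] logit tensors, one per K-fold model."""
    logits = torch.stack(fold_logits).mean(dim=0)        # model fusion: average the fold logits
    if multi_label:
        probs = torch.sigmoid(logits)                    # independent score per label
        return [id2label[i] for i, p in enumerate(probs) if p > threshold]
    return [id2label[int(torch.argmax(logits))]]         # single-label: take the argmax
```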

Test

  • If a test dataset is provided, set mode = 'test' to evaluate the model's performance on the test dataset. The output varies slightly depending on the task. For multi-class classification (including binary classification), the results for each category will be output, along with a badcase file.
  • When K-Folds is enabled, the logits from the K models for the same sample will be averaged during testing to calculate the evaluation metrics for the test set.
  • For multi-label classification, the evaluation results will only reflect the overall performance of the model across all labels.

Reference

License

This project is licensed under the Apache 2.0 license.

Citation

If you use this project in your research, please cite it as follows:

```bibtex
@misc{Text_Classifier,
  title={Text Classifier: A tool for training text classifier using pytorch.},
  author={Shouxian Li},
  year={2024},
  howpublished={\url{https://github.com/stanleylsx/text_classifier_torch}},
}
```
