Benoit Favre 2021-2024
This system converts a sequence of lowercase tokens without punctuation to a sequence of cased tokens with punctuation.
It is trained to predict both aspects at the token level in a multitask fashion, from fine-tuned BERT representations.
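As an illustration, the multitask setup roughly corresponds to a shared encoder with two token-level classification heads, as in the following sketch (assuming PyTorch and Hugging Face transformers; the class and variable names are illustrative, not the actual ones from recasepunc.py):

from torch import nn
from transformers import AutoModel

class CasePuncSketch(nn.Module):
    # Illustrative sketch: a shared BERT encoder with two token-level heads,
    # one for the case labels and one for the punctuation labels.
    def __init__(self, flavor='bert-base-uncased', num_case_labels=4, num_punc_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(flavor)
        hidden = self.bert.config.hidden_size
        self.case_head = nn.Linear(hidden, num_case_labels)
        self.punc_head = nn.Linear(hidden, num_punc_labels)

    def forward(self, input_ids, case_labels=None, punc_labels=None):
        hidden_states = self.bert(input_ids).last_hidden_state
        case_logits = self.case_head(hidden_states)
        punc_logits = self.punc_head(hidden_states)
        if case_labels is None:
            return case_logits, punc_logits
        # Multitask training: sum the two token-level cross-entropy losses.
        loss_fn = nn.CrossEntropyLoss()
        return (loss_fn(case_logits.transpose(1, 2), case_labels)
                + loss_fn(punc_logits.transpose(1, 2), punc_labels))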
The model predicts the following recasing labels:
- lower: keep lowercase
- upper: convert to upper case
- capitalize: set first letter as upper case
- other: leave as is (such tokens could be post-processed with a word list)
And the following punctuation labels:
- o: no punctuation
- period: .
- comma: ,
- question: ?
- exclamation: !
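As a rough sketch, applying a predicted pair of labels to a token amounts to something like the following (illustrative code, not the actual implementation; the function name is made up):

def apply_labels(token, case_label, punc_label):
    # Recasing: 'other' tokens are left as is (they could be handled with a word list).
    if case_label == 'lower':
        token = token.lower()
    elif case_label == 'upper':
        token = token.upper()
    elif case_label == 'capitalize':
        token = token[:1].upper() + token[1:]
    # Repunctuation: append the predicted punctuation mark, if any.
    marks = {'o': '', 'period': '.', 'comma': ',', 'question': '?', 'exclamation': '!'}
    return token + marks[punc_label]

# Example: apply_labels('hello', 'capitalize', 'comma') -> 'Hello,'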
Input tokens are batched as sequences of length 256 that are processed independently without overlap.
In training, batches containing fewer than 256 tokens are simulated by drawing a length uniformly at random and replacing all tokens and labels after that point with padding (called Cut-drop).
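A minimal sketch of the Cut-drop idea, assuming one example is a pair of token-id and label tensors of length 256 (names and padding values are illustrative):

import random

def cut_drop(input_ids, labels, max_length=256, pad_token_id=0, pad_label=0):
    # Simulate a shorter sequence: draw a cut point uniformly at random and
    # replace every token and label after that point with padding.
    cut = random.randint(1, max_length)
    input_ids = input_ids.clone()
    labels = labels.clone()
    input_ids[cut:] = pad_token_id
    labels[cut:] = pad_label
    return input_ids, labels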
Changelog:
- v0.4:
  - Retrain with latest transformers
  - Add support for Zh and En models
  - Fix generation when input is smaller than max length
Use your favourite method for installing Python requirements. For example:
python -m venv env
. env/bin/activate
pip3 install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
Predict from raw text:
python recasepunc.py predict checkpoint/path < input.txt > output.txt
All models are trained on the first 100M tokens of Common Crawl.
{
"iteration": "23000",
"train_loss": "0.015077149430289864",
"valid_loss": "0.021484553813934326",
"valid_accuracy_case": "0.9517227564102564",
"valid_accuracy_punc": "0.9359975961538461",
"valid_fscore": "{0: 0.6016615629196167, 1: 0.6202345490455627, 2: 0.6219512224197388, 3: 0.42424243688583374, 4: 0.08571428805589676}",
"config": "{'seed': 871253, 'lang': 'it', 'flavor': 'dbmdz/bert-base-italian-uncased', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/it-100M.train.x', 'data/it-100M.train.y', 'data/it-100M.valid.x', 'data/it-100M.valid.y', 'checkpoints/it'], 'pad_token_id': 0, 'cls_token_id': 102, 'cls_token': '[CLS]', 'sep_token_id': 103, 'sep_token': '[SEP]'}"
}
{
"iteration": "17000",
"train_loss": "0.007012549160048366",
"valid_loss": "0.007463883130978315",
"valid_accuracy_case": "0.9967948717948718",
"valid_accuracy_punc": "0.9682491987179487",
"valid_fscore": "{0: 0.7668336033821106, 1: 0.7813194990158081, 2: 0.7200000286102295, 3: 0.8333333730697632, 4: 0.7272727489471436}",
"config": "{'seed': 871253, 'lang': 'zh-Hant', 'flavor': 'ckiplab/bert-base-chinese', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/zh-Hant-100M.train.x', 'data/zh-Hant-100M.train.y', 'data/zh-Hant-100M.valid.x', 'data/zh-Hant-100M.valid.y', 'checkpoints/zh-Hant'], 'pad_token_id': 0, 'cls_token_id': 101, 'cls_token': '[CLS]', 'sep_token_id': 102, 'sep_token': '[SEP]'}"
}
{
"iteration": "22000",
"train_loss": "0.01467611983884126",
"valid_loss": "0.02559371789296468",
"valid_accuracy_case": "0.9393028846153846",
"valid_accuracy_punc": "0.9404046474358975",
"valid_fscore": "{0: 0.6431096196174622, 1: 0.603951096534729, 2: 0.7078340649604797, 3: 0.6865671277046204, 4: 0}",
"config": "{'seed': 871253, 'lang': 'en', 'flavor': 'bert-base-uncased', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/en-100M.train.x', 'data/en-100M.train.y', 'data/en-100M.valid.x', 'data/en-100M.valid.y', 'checkpoints/en'], 'pad_token_id': 0, 'cls_token_id': 101, 'cls_token': '[CLS]', 'sep_token_id': 102, 'sep_token': '[SEP]'}"
}
{
"iteration": "24000",
"train_loss": "0.015482447233051061",
"valid_loss": "0.006200919071069131",
"valid_accuracy_case": "1.0",
"valid_accuracy_punc": "0.9691506410256411",
"valid_fscore": "{0: 0.8114132881164551, 1: 0.7968379855155945, 2: 0.8446389436721802, 3: 0.8421052694320679, 4: 0.3076923191547394}",
"config": "{'seed': 871253, 'lang': 'fr', 'flavor': 'flaubert/flaubert_base_uncased', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/fr-100M.train.x', 'data/fr-100M.train.y', 'data/fr-100M.valid.x', 'data/fr-100M.valid.y', 'checkpoints/fr'], 'pad_token_id': 2, 'cls_token_id': 0, 'cls_token': '<s>', 'sep_token_id': 1, 'sep_token': '</s>'}"
}
Notes: you need to adapt the file names accordingly. Training tensors are precomputed and loaded into CPU memory; models and batches are moved to CUDA memory.
Stage 0: download text data
Stage 1: tokenize and normalize text with the Moses tokenizer, and extract recasing and repunctuation labels
python recasepunc.py preprocess --lang $LANG < input.txt > input.case+punc
Stage 2: sub-tokenize with the Flaubert/Moses tokenizer, and generate PyTorch tensors
python recasepunc.py tensorize input.case+punc input.case+punc.x input.case+punc.y --lang $LANG
Stage 3: train model
python recasepunc.py train train.x train.y valid.x valid.y checkpoint/path --lang $LANG
Stage 4: evaluate performance on a test set
python recasepunc.py eval test.x test.y checkpoint/path.iteration
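The memory layout mentioned in the note above roughly corresponds to the following (a sketch with illustrative names; the dummy model only stands in for the fine-tuned BERT):

import torch
from torch import nn

device = torch.device('cuda')
model = nn.Linear(768, 9).to(device)   # stand-in for the real model, which lives on the GPU

# Training tensors produced by the tensorize stage are loaded into CPU memory once...
x_train = torch.load('train.x')
y_train = torch.load('train.y')

for x_batch, y_batch in zip(x_train.split(16), y_train.split(16)):
    # ...and each batch is copied to CUDA memory only when it is needed.
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)
    # forward/backward pass on the GPU goes here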
Two scripts used to create the models are given as examples of how to train for a new language:
./prepare.sh <lang>
for downloading data, creating the data sets, and preprocessing
./train.sh <lang>
for training the model
Both assume the availability of an env.sh script for loading the environment and setting things up.
requirements.freeze.txt contains the package versions used for training.
You will need to modify recasepunc.py to set the BERT flavor for the new language, and check that the tokenizer correctly handles punctuation. For French, we had to patch the tokenizer to keep the input and punctuation synchronized.
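One quick way to perform that check is to tokenize a few short strings ending in punctuation and verify that the marks survive as separate tokens (a sketch using the Hugging Face tokenizer API; the flavor is just one of the examples above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('flaubert/flaubert_base_uncased')
for text in ['bonjour .', 'bonjour ,', 'bonjour ?', 'bonjour !']:
    # Each punctuation mark should remain a separate sub-token, aligned with the input.
    print(text, '->', tokenizer.tokenize(text))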
This work was not published, but a similar model is described in "FullStop: Multilingual Deep Models for Punctuation Prediction", Frank et al., 2021.