Start bert news
kuk committed Mar 10, 2020
1 parent a83be5d commit 599c8c2
Showing 22 changed files with 570 additions and 222 deletions.
5 changes: 4 additions & 1 deletion Makefile
@@ -1,6 +1,9 @@
 
 test:
-	pytest -vv --pep8 --flakes --cov slovnet --cov-report term-missing
+	pytest -vv \
+		--pep8 --flakes \
+		--cov slovnet --cov-report term-missing \
+		slovnet
 
 ci:
 	pytest -vv slovnet/tests/test_api.py
21 changes: 13 additions & 8 deletions README.md
@@ -297,10 +297,6 @@ For every column top 3 results are highlighted. In each case `slovnet` and `deep
 </tbody>
 </table>
 
-## License
-
-Source code of `slovnet` is distributed under the MIT license (allows modification and commercial usage)
-
 ## Support
 
 - Chat — https://telegram.me/natural_language_processing
@@ -312,15 +308,24 @@ Rent GPU:
 
 ```bash
 vast search offers | grep '1 x RTX 2080 Ti'
-vast create instance 474463 --image alexkuk/my-vast --disk 20
-vast destroy instance 482511
+vast create instance 420232 --image alexkuk/my-vast --disk 20
+vast destroy instance 488468
 watch vast show instances
 
 ssh -Nf vast -L 8888:localhost:8888 -L 6006:localhost:6006
-http://localhost:8888/
-http://localhost:6006
+http://localhost:8888/notebooks/
+http://localhost:6006/
 
 scp ~/.slovnet.json vast:~
 rsync --exclude data --exclude notes -rv . vast:~/slovnet
+rsync -u --exclude data --exclude runs -rv 'vast:~/slovnet/*' .
 
 ```
+
+Install dev:
+
+```bash
+pip3 install -e slovnet
+pip3 install -r slovnet/requirements/dev.txt
+
+```
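In the hunk above, `ssh -Nf` opens background tunnels without running a remote command; port 8888 forwards the Jupyter server and 6006 TensorBoard (presumably fed by the `runs/` logs the notebook below writes).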
5 changes: 2 additions & 3 deletions requirements/dev.txt
@@ -1,8 +1,7 @@
 torch>=1.1.0
-matplotlib
+boto3
 tqdm
 
 navec>=0.6.0
-nerus>=1.4.0
-razdel>=0.4.0
 
+corus>=0.5.0
277 changes: 276 additions & 1 deletion scripts/01_bert_news/main.ipynb
@@ -6,9 +6,284 @@
"metadata": {},
"outputs": [],
"source": [
"%run -n main.py\n",
"%run main.py\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !mkdir -p data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/taiga/Fontanka.tar.gz -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/ods/gazeta_v1.csv.zip -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/ods/interfax_v1.csv.zip -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/lenta-ru-news.csv.gz -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2014.tar.bz2 -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2015-part1.tar.bz2 -P data/raw\n",
"# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2015-part2.tar.bz2 -P data/raw"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# LOADS = {\n",
"# 'gazeta_v1.csv.zip': load_ods_gazeta,\n",
"# 'interfax_v1.csv.zip': load_ods_interfax,\n",
"# 'Fontanka.tar.gz': load_taiga_fontanka,\n",
"# 'lenta-ru-news.csv.gz': load_lenta,\n",
"# 'news-articles-2015-part1.tar.bz2': load_buriy_news,\n",
"# 'news-articles-2015-part2.tar.bz2': load_buriy_news,\n",
"# 'news-articles-2014.tar.bz2': load_buriy_news,\n",
"# }\n",
"\n",
"\n",
"# lines = [] # Requires 15Gb RAM\n",
"# for name in listdir('data/raw'):\n",
"# path = 'data/raw/' + name\n",
"# records = LOADS[name](path)\n",
"# for record in log_progress(records, desc=name):\n",
"# line = re.sub('\\s+', ' ', record.text)\n",
"# lines.append(line)"
]
},
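{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each entry in `LOADS` maps a dump file name to its corus loader, so one loop ingests all seven archives; `record.text` is collapsed to single-spaced text to get one document per line in the dumps below."
]
},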
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# seed(1)\n",
"# shuffle(lines)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# cap = 1000\n",
"# dump_lines(lines[:cap], 'data/test.txt')\n",
"# dump_lines(log_progress(lines[cap:]), 'data/train.txt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# upload('data/test.txt')\n",
"# upload('data/train.txt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not exists('data/test.txt'):\n",
" download('data/test.txt')\n",
" download('data/train.txt')"
]
},
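{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A guess at the `upload`/`download` helpers main.py provides: boto3 is in\n",
"# the dev requirements and the data lives in S3-compatible storage. Bucket\n",
"# name is hypothetical; kept commented out so the real helpers from main.py\n",
"# stay in effect.\n",
"\n",
"# import boto3\n",
"\n",
"# S3 = boto3.client('s3')  # credentials presumably come via ~/.slovnet.json\n",
"# BUCKET = 'slovnet'  # hypothetical name\n",
"\n",
"# def upload(path):\n",
"#     S3.upload_file(path, BUCKET, path)\n",
"\n",
"# def download(path):\n",
"#     S3.download_file(BUCKET, path, path)"
]
},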
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not exists('rubert/vocab.txt'):\n",
" for name in ['vocab.txt', 'emb.pt', 'encoder.pt', 'mlm.pt']:\n",
" download('rubert/' + name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = get_device()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"items = list(load_lines('rubert/vocab.txt'))\n",
"vocab = BERTVocab(items)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = BERTConfig(\n",
" vocab_size=50106,\n",
" seq_len=512,\n",
" emb_dim=768,\n",
" layers_num=12,\n",
" heads_num=12,\n",
" hidden_dim=3072,\n",
" dropout=0.1,\n",
" norm_eps=1e-12\n",
")\n",
"emb = BERTEmbedding(\n",
" config.vocab_size, config.seq_len, config.emb_dim,\n",
" config.dropout, config.norm_eps\n",
")\n",
"emb.position.requires_grad = False # fix pos emb to train on short seqs\n",
"encoder = BERTEncoder(\n",
" config.layers_num, config.emb_dim, config.heads_num, config.hidden_dim,\n",
" config.dropout, config.norm_eps\n",
")\n",
"mlm = BERTMLMHead(config.emb_dim, config.vocab_size)\n",
"model = BERTMLM(emb, encoder, mlm)\n",
"\n",
"load_model(model, 'rubert')\n",
"model = model.to(device)\n",
"\n",
"criterion = flatten_cross_entropy"
]
},
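{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Back-of-envelope weight count for the config above (biases and layer\n",
"# norms ignored) -- a sanity check that this is BERT-base sized.\n",
"emb_params = (config.vocab_size + config.seq_len) * config.emb_dim\n",
"layer_params = 4 * config.emb_dim ** 2 + 2 * config.emb_dim * config.hidden_dim\n",
"mlm_params = config.emb_dim * config.vocab_size  # assumes an untied MLM projection\n",
"emb_params + config.layers_num * layer_params + mlm_params  # ~162M"
]
},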
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1)\n",
"seed(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encode = BERTMLMEncoder(\n",
" vocab,\n",
" seq_len=128,\n",
" batch_size=32,\n",
" shuffle_size=10000\n",
")\n",
"\n",
"lines = load_lines('data/test.txt')\n",
"batches = encode(lines)\n",
"test_batches = [_.to(device) for _ in batches]\n",
"\n",
"lines = load_lines('data/train.txt')\n",
"batches = encode(lines)\n",
"train_batches = (_.to(device) for _ in batches)"
]
},
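{
"cell_type": "markdown",
"metadata": {},
"source": [
"`test_batches` is a list, materialized once and kept on the device for repeated evaluation; `train_batches` is a generator, so training batches are encoded and moved to the device lazily in a single pass over `train.txt`."
]
},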
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"board = Board('01', 'runs')\n",
"train_board = board.section('01_train')\n",
"test_board = board.section('02_test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
"model, optimizer = amp.initialize(model, optimizer, opt_level='O2')\n",
"scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1)"
]
},
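{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `amp` above is NVIDIA apex mixed precision (opt_level O2). On torch>=1.6\n",
"# the built-in AMP is a rough equivalent -- a sketch, not what this notebook\n",
"# actually runs:\n",
"#\n",
"# scaler = torch.cuda.amp.GradScaler()\n",
"# with torch.cuda.amp.autocast():\n",
"#     batch = process_batch(model, criterion, batch)\n",
"# scaler.scale(batch.loss / accum_steps).backward()\n",
"# scaler.step(optimizer)\n",
"# scaler.update()"
]
},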
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_meter = MLMScoreMeter()\n",
"test_meter = MLMScoreMeter()\n",
"\n",
"accum_steps = 64 # 2K batch\n",
"log_steps = 256\n",
"eval_steps = 512\n",
"save_steps = eval_steps * 10\n",
"\n",
"model.train()\n",
"optimizer.zero_grad()\n",
"\n",
"for step, batch in log_progress(enumerate(train_batches)):\n",
" batch = process_batch(model, criterion, batch)\n",
" batch.loss /= accum_steps\n",
" \n",
" with amp.scale_loss(batch.loss, optimizer) as scaled:\n",
" scaled.backward()\n",
"\n",
" score = score_batch(batch, ks=())\n",
" train_meter.add(score)\n",
"\n",
" if every(step, log_steps):\n",
" train_meter.write(train_board)\n",
" train_meter.reset()\n",
"\n",
" if every(step, accum_steps):\n",
" optimizer.step()\n",
" scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" if every(step, eval_steps):\n",
" batches = infer_batches(model, criterion, test_batches)\n",
" scores = score_batches(batches)\n",
" test_meter.extend(scores)\n",
" test_meter.write(test_board)\n",
" test_meter.reset()\n",
" \n",
" if every(step, save_steps):\n",
" dump_model(model, 'model')\n",
" for name in ['emb.pt', 'encoder.pt', 'mlm.pt']:\n",
" upload('model/' + name)\n",
" \n",
" board.step()"
]
},
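{
"cell_type": "markdown",
"metadata": {},
"source": [
"Gradient accumulation: with `batch_size=32` and `accum_steps=64` the optimizer steps once per 64 * 32 = 2048 samples, hence the \"2K batch\" comment. Dividing `batch.loss` by `accum_steps` before `backward()` makes the accumulated gradient an average over the effective batch instead of a sum."
]
},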
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {