Skip to content

Commit 797a353

Browse files
committed
Implemented foundational code and wrote documentation
1 parent ec0a9c3 commit 797a353

12 files changed

Lines changed: 1812 additions & 0 deletions

File tree

.gitignore

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/data
2+
/datasets
3+
/outputs
4+
5+
# Byte-compiled / optimized / DLL files
6+
__pycache__/
7+
*.py[cod]
8+
*$py.class
9+
10+
# C extensions
11+
*.so
12+
13+
# Distribution / packaging
14+
.Python
15+
build/
16+
develop-eggs/
17+
dist/
18+
downloads/
19+
eggs/
20+
.eggs/
21+
lib/
22+
lib64/
23+
parts/
24+
sdist/
25+
var/
26+
wheels/
27+
share/python-wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST
32+
33+
# PyInstaller
34+
# Usually these files are written by a python script from a template
35+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
36+
*.manifest
37+
*.spec
38+
39+
# Installer logs
40+
pip-log.txt
41+
pip-delete-this-directory.txt
42+
43+
# Unit test / coverage reports
44+
htmlcov/
45+
.tox/
46+
.nox/
47+
.coverage
48+
.coverage.*
49+
.cache
50+
nosetests.xml
51+
coverage.xml
52+
*.cover
53+
*.py,cover
54+
.hypothesis/
55+
.pytest_cache/
56+
cover/
57+
58+
# Translations
59+
*.mo
60+
*.pot
61+
62+
# Django stuff:
63+
*.log
64+
local_settings.py
65+
db.sqlite3
66+
db.sqlite3-journal
67+
68+
# Flask stuff:
69+
instance/
70+
.webassets-cache
71+
72+
# Scrapy stuff:
73+
.scrapy
74+
75+
# Sphinx documentation
76+
docs/_build/
77+
78+
# PyBuilder
79+
.pybuilder/
80+
target/
81+
82+
# Jupyter Notebook
83+
.ipynb_checkpoints
84+
85+
# IPython
86+
profile_default/
87+
ipython_config.py
88+
89+
# pyenv
90+
# For a library or package, you might want to ignore these files since the code is
91+
# intended to run in multiple environments; otherwise, check them in:
92+
# .python-version
93+
94+
# pipenv
95+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
97+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
98+
# install all needed dependencies.
99+
#Pipfile.lock
100+
101+
# poetry
102+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103+
# This is especially recommended for binary packages to ensure reproducibility, and is more
104+
# commonly ignored for libraries.
105+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106+
#poetry.lock
107+
108+
# pdm
109+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110+
#pdm.lock
111+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112+
# in version control.
113+
# https://pdm.fming.dev/#use-with-ide
114+
.pdm.toml
115+
116+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117+
__pypackages__/
118+
119+
# Celery stuff
120+
celerybeat-schedule
121+
celerybeat.pid
122+
123+
# SageMath parsed files
124+
*.sage.py
125+
126+
# Environments
127+
.env
128+
.venv
129+
env/
130+
venv/
131+
ENV/
132+
env.bak/
133+
venv.bak/
134+
135+
# Spyder project settings
136+
.spyderproject
137+
.spyproject
138+
139+
# Rope project settings
140+
.ropeproject
141+
142+
# mkdocs documentation
143+
/site
144+
145+
# mypy
146+
.mypy_cache/
147+
.dmypy.json
148+
dmypy.json
149+
150+
# Pyre type checker
151+
.pyre/
152+
153+
# pytype static type analyzer
154+
.pytype/
155+
156+
# Cython debug symbols
157+
cython_debug/
158+
159+
# PyCharm
160+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162+
# and can be added to the global gitignore or merged into this file. For a more nuclear
163+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
164+
#.idea/

.tool-versions

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python 3.10.9

README.md

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# BERT Classification Tutorial
2+
3+
4+
## はじめに
5+
6+
7+
## Installation & データセット準備
8+
9+
本実装は**Python 3.10以上**での実行を想定しています。
10+
Python 3.10は、match文の導入やwith文の改善など様々な利便性の向上がなされている他、[Pythonが高速化の計画を進めていること](https://forest.watch.impress.co.jp/docs/news/1451751.html)もあり、早めに新しいPythonに適応しておくことのメリットは大きいと考えたためです。
11+
12+
また、Python 3.10では、type hints (型注釈)が以前のバージョンより自然に書けるようになっており、今までよりも堅牢かつ可読性の高いコードを書きやすくなっています。
13+
そのため、公開実装のためのPythonとしても優れていると考えました。
14+
15+
### Install with poetry
16+
17+
```bash
18+
poetry install
19+
```
20+
21+
### Install with conda & pip
22+
23+
24+
https://pytorch.org/get-started/locally/
25+
26+
```bash
27+
conda create -n bert-classification-tutorial python=3.10
28+
29+
conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia
30+
31+
pip install tqdm "transformers[ja,sentencepiece]" classopt tokenizers numpy pandas more-itertools scikit-learn scipy
32+
```
33+
34+
35+
### データセット作成
36+
37+
本実装では、分類対象のテキストとしてRONDHUIT社が公開する[livedoorニュースコーパス](http://www.rondhuit.com/download.html#ldcc)を用います。
38+
livedoorニュースコーパスは、9つのカテゴリのニュース記事が集約されたデータセットです。
39+
通常、ニュース記事のタイトルと本文を用いて、そのニュース記事がどのカテゴリにあてはまるかを分類する9値分類を行います。
40+
41+
本実装では、以下のコマンドを実行すればデータセットの準備が完了するようになっています。
42+
43+
44+
```bash
45+
bash src/download.sh
46+
47+
poetry run python src/prepare.py
48+
// python src/prepare.py
49+
```
50+
51+
52+
流れとしては、まず`src/download.sh`がデータセットのダウンロードと生データの展開を行います。
53+
54+
次に、`src/prepare.py`を実行することで、生データをJSONL形式(1行ごとにJSON形式のデータが書き込まれている形式)に変換します。
55+
その際、NFKC正規化などの前処理も実行します。
56+
57+
さらに、分類モデルの訓練のため、分類先となるカテゴリを文字列から数値に変換し、その変換表を保存します。
58+
59+
また、全データを訓練(train):開発(val):テスト(test)=8:1:1の割合に分割します。
60+
これにより、訓練中に開発セットを用いて、モデルが過学習していないかの確認が行えるようになります。
61+
テストセットは最終的な評価にのみ用います。
62+
63+
64+
### 訓練
65+
66+
以下のコマンドを実行することで、`cl-tohoku/bert-base-japanese-v2`を用いたテキスト分類モデルの訓練が実行できます。
67+
68+
```bash
69+
poetry run python src/train.py --model_name cl-tohoku/bert-base-japanese-v2
70+
```
71+
72+
この時、`--model_name`に与える引数を例えば`bert-base-multilingual-cased`にすることで、多言語BERTを用いた学習が実行できます。
73+
74+
また、ほとんどの設定をコマンドライン引数として与えら得れるようにしているので、以下のように複数の設定を変更して実行することも可能です。
75+
76+
```bash
77+
poetry run python src/train.py \
78+
--model_name studio-ousia/luke-japanese-base-lite \
79+
--batch_size 32 \
80+
--epochs 10 \
81+
--lr 5e-5 \
82+
--num_warmup_epochs 1
83+
```
84+
85+
本実装では学習後のモデルは保存せず、訓練のたびに評価値を算出し、評価値のみを保存するようにしています。
86+
学習済みモデルを保存→保存済みモデルを読み込んで評価、という流れの実装をよく見ますが、この実装は実験の途中でどのモデルを使用していたのか忘れてしまったり、モデルの構造が学習時と変わってしまっていたり、評価用データを間違えてしまったり、といった問題が発生しやすいと考えています。
87+
88+
そこで本実装では、訓練のたびに必要な評価を行ってその結果のみを保存しておき、モデルは保存しない方針を採用しました。
89+
これにより、モデルの構造を変化させたり、学習・評価データを変化させた場合でも、訓練をし直すだけで常に間違いのない結果を得られます。
90+
研究における実験プロセスの中では、間違いのない実験結果を積み重ねていくことが、研究を進めていく上で最も重要だと考えているので、間違いが発生しづらいこの方針はスジがよいと考えています。
91+
92+
本実装において、実験結果は `outputs/[モデル名]/[年月日]/[時分秒]`のディレクトリに保存されます。
93+
実際には、以下のようなディレクトリ構造になります。
94+
95+
```
96+
outputs/bert-base-multilingual-cased
97+
└── 2023-01-13
98+
└── 05-38-02
99+
├── config.json
100+
├── log.csv
101+
├── test-metrics.json
102+
└── val-metrics.json
103+
```
104+
105+
`config.json`が実験時の設定で、このファイルに記述してある値を用いることで、同じ実験を再現することができるようにしてあります。
106+
また、`log.csv`に学習過程における開発セットでのepochごとの評価値を記録してあります。
107+
そして、`val-metrics.json``test-metrics.json`に、開発セットの評価値が最もよかった時点でのモデルを用いた、開発セットとテストセットに対する評価値を記録してあります。
108+
109+
実際の`test-metrics.json`は以下のようになっています。
110+
111+
```json:test-metrics.json
112+
{
113+
"loss": 2.845567681340744,
114+
"accuracy": 0.9619565217391305,
115+
"precision": 0.9561782755165722,
116+
"recall": 0.9562792450871114,
117+
"f1": 0.9559338777925345
118+
}
119+
```
120+
121+
## 評価実験
122+
123+
最後に、本実装によって、livedoorニュースコーパスの9値分類を行う評価実験を実施しました。
124+
125+
注意点ですが、実験は単一の乱数シード値で1度しか実施しておらず、分割交差検証も行っていないので、実験結果の正確性は高くありません。
126+
したがって、以下の結果は過度に信用せず、参考程度に見てもらうよう、お願いいたします。
127+
128+
では、結果の表を以下に示します。
129+
baseサイズのモデルとlargeサイズのモデルの2種類にモデルを大別して結果をまとめました。
130+
なお、Accuracy (正解率)以外の値、つまりPresicion (精度)、Recall (再現率)、F1はmacro平均を取った値です。
131+
また、すべての値は%表記です。
132+
133+
| base models | Accuracy | Precision | Recall | F1 |
134+
| ------------------------------------------------------------------------------------------------------------------------- | --------- | --------- | --------- | --------- |
135+
| [cl-tohoku/bert-base-japanese-v2](https://huggingface.co/cl-tohoku/bert-base-japanese-v2) | **97.15** | **96.82** | **96.55** | **96.64** |
136+
| [cl-tohoku/bert-base-japanese-char-v2](https://huggingface.co/cl-tohoku/bert-base-japanese-char-v2) | 96.20 | 95.54 | 95.21 | 95.34 |
137+
| [cl-tohoku/bert-base-japanese](https://huggingface.co/cl-tohoku/bert-base-japanese) | 96.47 | 96.15 | 95.67 | 95.83 |
138+
| [cl-tohoku/bert-base-japanese-whole-word-masking](https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking) | 96.74 | 96.43 | 95.97 | 96.13 |
139+
| [cl-tohoku/bert-base-japanese-char](https://huggingface.co/cl-tohoku/bert-base-japanese-char) | 95.65 | 94.98 | 94.88 | 94.89 |
140+
| | | | | |
141+
| [studio-ousia/luke-japanese-base-lite](https://huggingface.co/studio-ousia/luke-japanese-base-lite) | 96.88 | 96.53 | 96.47 | 96.48 |
142+
| | | | | |
143+
| [bert-base-multilingual-cased](https://huggingface.co/bert-base-multilingual-cased) | 96.20 | 95.62 | 95.63 | 95.59 |
144+
| [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | 96.20 | 95.65 | 95.60 | 95.61 |
145+
| [studio-ousia/mluke-base-lite](https://huggingface.co/studio-ousia/mluke-base-lite) | 96.47 | 95.82 | 95.94 | 95.86 |
146+
147+
まず、baseサイズのモデルの結果について観察すると、今回の実験では東北大BERTのバージョン2 (bert-base-japanese-v2)が最も高い性能になったことがわかります。
148+
Accuracyが97.15、F1が96.64と、かなり高い割合で正しく分類することができていると思います。
149+
東北大が公開しているモデルのbert-base-japanese-whole-word-maskingと比較して、bert-base-japanese-v2の方が性能が高く、東北大BERTの中だと、今後は最初にbert-base-japanese-v2を使って問題なさそうだという印象です。
150+
151+
次点はStudio Ousiaの日本語LUKEで、こちらも非常に高い割合で正しく分類を行えていると思います。
152+
153+
文字ベースのモデル(cl-tohoku/bert-base-japanese-char-v2など)は、他のモデルと比較して若干性能が低いですが、十分高い性能であるといえると思います。
154+
155+
多言語モデルの中では、Studio OusiaのmLUKEが最も高い性能になりました。
156+
157+
158+
| large models | Accuracy | Precision | Recall | F1 |
159+
| ----------------------------------------------------------------------------------------------------- | --------- | --------- | --------- | --------- |
160+
| [cl-tohoku/bert-large-japanese](https://huggingface.co/cl-tohoku/bert-large-japanese) | **97.69** | **97.50** | 96.84 | **97.10** |
161+
| [studio-ousia/luke-japanese-large-lite](https://huggingface.co/studio-ousia/luke-japanese-large-lite) | 97.55 | 97.38 | 96.85 | 97.06 |
162+
| | | | | |
163+
| [xlm-roberta-large](https://huggingface.co/xlm-roberta-large) | 97.15 | 96.73 | 96.71 | 96.70 |
164+
| [studio-ousia/mluke-large-lite](https://huggingface.co/studio-ousia/mluke-large-lite) | 97.42 | 97.25 | **96.97** | 97.08 |
165+
166+
次に、largeサイズのモデルの結果について観察すると、AccuracyとF1では東北大BERT (large)が最も高い性能になりましたが、Studio Ousiaの日本語LUKEや多言語LUKEと比較して、ほとんど同じ性能になりました。
167+
全体として、baseサイズのモデルよりも高い性能となっており、モデルサイズを増大させることによる性能向上が観察できました。
168+
169+
## 参考文献
170+
171+
- [【実装解説】日本語版BERTでlivedoorニュース分類:Google Colaboratoryで(PyTorch)](https://qiita.com/sugulu_Ogawa_ISID/items/697bd03499c1de9cf082)
172+
- [Livedoorニュースコーパスを文書分類にすぐ使えるように整形する](https://radiology-nlp.hatenablog.com/entry/2019/11/25/124219)
173+
- https://github.com/yoheikikuta/bert-japanese/blob/master/notebook/finetune-to-livedoor-corpus.ipynb
174+
- https://github.com/sonoisa/t5-japanese/blob/main/t5_japanese_classification.ipynb
175+
176+
## 引用
177+
178+
作者: [Hayato Tsukagoshi](https://hpprc.dev) \
179+
email: [research.tsukagoshi.hayato@gmail.com](mailto:research.tsukagoshi.hayato@gmail.com)
180+
181+
182+
```bibtex
183+
@misc{
184+
hayato-tsukagoshi-2023-bert-classification-tutorial,
185+
title = {{BERT Classification Tutorial}},
186+
author = {Hayato Tsukagoshi},
187+
year = {2023},
188+
publisher = {GitHub},
189+
journal = {GitHub repository},
190+
howpublished = {\url{https://github.com/hppRC/bert-classification-tutorial}},
191+
url = {https://github.com/hppRC/bert-classification-tutorial},
192+
}
193+
```

metrics.csv

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
model_name,loss,accuracy,precision,recall,f1
2+
cl-tohoku/bert-large-japanese,1.8636954729344049,0.9769021739130435,0.9750378536488287,0.9683866534528303,0.9709886899708837
3+
studio-ousia/mluke-large-lite,2.462220841489793,0.9741847826086957,0.9724700612250123,0.9696959807981327,0.9708172275698685
4+
studio-ousia/luke-japanese-large-lite,2.201103634626159,0.9755434782608695,0.9737537895304548,0.9684910092403877,0.9706240612459918
5+
xlm-roberta-large,2.355739393742229,0.9714673913043478,0.9672815110515471,0.9670801683912984,0.9670110692671131
6+
cl-tohoku/bert-base-japanese-v2,2.552672197131197,0.9714673913043478,0.968187644904622,0.9655247714489725,0.9664236077143215
7+
studio-ousia/luke-japanese-base-lite,2.9518878437862126,0.96875,0.9653349755858047,0.9646585861518412,0.9647895664897193
8+
cl-tohoku/bert-base-japanese-whole-word-masking,2.754634548019132,0.967391304347826,0.9643093526155185,0.9596735451242144,0.9613454189698499
9+
studio-ousia/mluke-base-lite,2.323811716562056,0.9646739130434783,0.9581897824634483,0.9593519920275893,0.9585694731642673
10+
cl-tohoku/bert-base-japanese,2.5077272631308953,0.9646739130434783,0.9614887868830029,0.9566602153441597,0.958255241534905
11+
xlm-roberta-base,2.822538128649087,0.9619565217391305,0.9564778703744456,0.9560111793601512,0.9560647293589007
12+
bert-base-multilingual-cased,2.845567681340744,0.9619565217391305,0.9561782755165722,0.9562792450871114,0.9559338777925345
13+
cl-tohoku/bert-base-japanese-char-v2,3.1727957457263507,0.9619565217391305,0.955424825465865,0.952106175189902,0.953410810716683
14+
cl-tohoku/bert-base-japanese-char,3.3101620054196403,0.9565217391304348,0.9498454773625516,0.9487548103898411,0.9488789542839646

0 commit comments

Comments
 (0)