|
| 1 | +# BERT Classification Tutorial |
| 2 | + |
| 3 | + |
| 4 | +## はじめに |
| 5 | + |
| 6 | + |
| 7 | +## Installation & データセット準備 |
| 8 | + |
| 9 | +本実装は**Python 3.10以上**での実行を想定しています。 |
| 10 | +Python 3.10は、match文の導入やwith文の改善など様々な利便性の向上がなされている他、[Pythonが高速化の計画を進めていること](https://forest.watch.impress.co.jp/docs/news/1451751.html)もあり、早めに新しいPythonに適応しておくことのメリットは大きいと考えたためです。 |
| 11 | + |
| 12 | +また、Python 3.10では、type hints (型注釈)が以前のバージョンより自然に書けるようになっており、今までよりも堅牢かつ可読性の高いコードを書きやすくなっています。 |
| 13 | +そのため、公開実装のためのPythonとしても優れていると考えました。 |
| 14 | + |
| 15 | +### Install with poetry |
| 16 | + |
| 17 | +```bash |
| 18 | +poetry install |
| 19 | +``` |
| 20 | + |
| 21 | +### Install with conda & pip |
| 22 | + |
| 23 | + |
| 24 | +https://pytorch.org/get-started/locally/ |
| 25 | + |
| 26 | +```bash |
| 27 | +conda create -n bert-classification-tutorial python=3.10 |
| 28 | + |
| 29 | +conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia |
| 30 | + |
| 31 | +pip install tqdm "transformers[ja,sentencepiece]" classopt tokenizers numpy pandas more-itertools scikit-learn scipy |
| 32 | +``` |
| 33 | + |
| 34 | + |
| 35 | +### データセット作成 |
| 36 | + |
| 37 | +本実装では、分類対象のテキストとしてRONDHUIT社が公開する[livedoorニュースコーパス](http://www.rondhuit.com/download.html#ldcc)を用います。 |
| 38 | +livedoorニュースコーパスは、9つのカテゴリのニュース記事が集約されたデータセットです。 |
| 39 | +通常、ニュース記事のタイトルと本文を用いて、そのニュース記事がどのカテゴリにあてはまるかを分類する9値分類を行います。 |
| 40 | + |
| 41 | +本実装では、以下のコマンドを実行すればデータセットの準備が完了するようになっています。 |
| 42 | + |
| 43 | + |
| 44 | +```bash |
| 45 | +bash src/download.sh |
| 46 | + |
| 47 | +poetry run python src/prepare.py |
| 48 | +// python src/prepare.py |
| 49 | +``` |
| 50 | + |
| 51 | + |
| 52 | +流れとしては、まず`src/download.sh`がデータセットのダウンロードと生データの展開を行います。 |
| 53 | + |
| 54 | +次に、`src/prepare.py`を実行することで、生データをJSONL形式(1行ごとにJSON形式のデータが書き込まれている形式)に変換します。 |
| 55 | +その際、NFKC正規化などの前処理も実行します。 |
| 56 | + |
| 57 | +さらに、分類モデルの訓練のため、分類先となるカテゴリを文字列から数値に変換し、その変換表を保存します。 |
| 58 | + |
| 59 | +また、全データを訓練(train):開発(val):テスト(test)=8:1:1の割合に分割します。 |
| 60 | +これにより、訓練中に開発セットを用いて、モデルが過学習していないかの確認が行えるようになります。 |
| 61 | +テストセットは最終的な評価にのみ用います。 |
| 62 | + |
| 63 | + |
| 64 | +### 訓練 |
| 65 | + |
| 66 | +以下のコマンドを実行することで、`cl-tohoku/bert-base-japanese-v2`を用いたテキスト分類モデルの訓練が実行できます。 |
| 67 | + |
| 68 | +```bash |
| 69 | +poetry run python src/train.py --model_name cl-tohoku/bert-base-japanese-v2 |
| 70 | +``` |
| 71 | + |
| 72 | +この時、`--model_name`に与える引数を例えば`bert-base-multilingual-cased`にすることで、多言語BERTを用いた学習が実行できます。 |
| 73 | + |
| 74 | +また、ほとんどの設定をコマンドライン引数として与えら得れるようにしているので、以下のように複数の設定を変更して実行することも可能です。 |
| 75 | + |
| 76 | +```bash |
| 77 | +poetry run python src/train.py \ |
| 78 | + --model_name studio-ousia/luke-japanese-base-lite \ |
| 79 | + --batch_size 32 \ |
| 80 | + --epochs 10 \ |
| 81 | + --lr 5e-5 \ |
| 82 | + --num_warmup_epochs 1 |
| 83 | +``` |
| 84 | + |
| 85 | +本実装では学習後のモデルは保存せず、訓練のたびに評価値を算出し、評価値のみを保存するようにしています。 |
| 86 | +学習済みモデルを保存→保存済みモデルを読み込んで評価、という流れの実装をよく見ますが、この実装は実験の途中でどのモデルを使用していたのか忘れてしまったり、モデルの構造が学習時と変わってしまっていたり、評価用データを間違えてしまったり、といった問題が発生しやすいと考えています。 |
| 87 | + |
| 88 | +そこで本実装では、訓練のたびに必要な評価を行ってその結果のみを保存しておき、モデルは保存しない方針を採用しました。 |
| 89 | +これにより、モデルの構造を変化させたり、学習・評価データを変化させた場合でも、訓練をし直すだけで常に間違いのない結果を得られます。 |
| 90 | +研究における実験プロセスの中では、間違いのない実験結果を積み重ねていくことが、研究を進めていく上で最も重要だと考えているので、間違いが発生しづらいこの方針はスジがよいと考えています。 |
| 91 | + |
| 92 | +本実装において、実験結果は `outputs/[モデル名]/[年月日]/[時分秒]`のディレクトリに保存されます。 |
| 93 | +実際には、以下のようなディレクトリ構造になります。 |
| 94 | + |
| 95 | +``` |
| 96 | +outputs/bert-base-multilingual-cased |
| 97 | +└── 2023-01-13 |
| 98 | + └── 05-38-02 |
| 99 | + ├── config.json |
| 100 | + ├── log.csv |
| 101 | + ├── test-metrics.json |
| 102 | + └── val-metrics.json |
| 103 | +``` |
| 104 | + |
| 105 | +`config.json`が実験時の設定で、このファイルに記述してある値を用いることで、同じ実験を再現することができるようにしてあります。 |
| 106 | +また、`log.csv`に学習過程における開発セットでのepochごとの評価値を記録してあります。 |
| 107 | +そして、`val-metrics.json`と`test-metrics.json`に、開発セットの評価値が最もよかった時点でのモデルを用いた、開発セットとテストセットに対する評価値を記録してあります。 |
| 108 | + |
| 109 | +実際の`test-metrics.json`は以下のようになっています。 |
| 110 | + |
| 111 | +```json:test-metrics.json |
| 112 | +{ |
| 113 | + "loss": 2.845567681340744, |
| 114 | + "accuracy": 0.9619565217391305, |
| 115 | + "precision": 0.9561782755165722, |
| 116 | + "recall": 0.9562792450871114, |
| 117 | + "f1": 0.9559338777925345 |
| 118 | +} |
| 119 | +``` |
| 120 | + |
| 121 | +## 評価実験 |
| 122 | + |
| 123 | +最後に、本実装によって、livedoorニュースコーパスの9値分類を行う評価実験を実施しました。 |
| 124 | + |
| 125 | +注意点ですが、実験は単一の乱数シード値で1度しか実施しておらず、分割交差検証も行っていないので、実験結果の正確性は高くありません。 |
| 126 | +したがって、以下の結果は過度に信用せず、参考程度に見てもらうよう、お願いいたします。 |
| 127 | + |
| 128 | +では、結果の表を以下に示します。 |
| 129 | +baseサイズのモデルとlargeサイズのモデルの2種類にモデルを大別して結果をまとめました。 |
| 130 | +なお、Accuracy (正解率)以外の値、つまりPresicion (精度)、Recall (再現率)、F1はmacro平均を取った値です。 |
| 131 | +また、すべての値は%表記です。 |
| 132 | + |
| 133 | +| base models | Accuracy | Precision | Recall | F1 | |
| 134 | +| ------------------------------------------------------------------------------------------------------------------------- | --------- | --------- | --------- | --------- | |
| 135 | +| [cl-tohoku/bert-base-japanese-v2](https://huggingface.co/cl-tohoku/bert-base-japanese-v2) | **97.15** | **96.82** | **96.55** | **96.64** | |
| 136 | +| [cl-tohoku/bert-base-japanese-char-v2](https://huggingface.co/cl-tohoku/bert-base-japanese-char-v2) | 96.20 | 95.54 | 95.21 | 95.34 | |
| 137 | +| [cl-tohoku/bert-base-japanese](https://huggingface.co/cl-tohoku/bert-base-japanese) | 96.47 | 96.15 | 95.67 | 95.83 | |
| 138 | +| [cl-tohoku/bert-base-japanese-whole-word-masking](https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking) | 96.74 | 96.43 | 95.97 | 96.13 | |
| 139 | +| [cl-tohoku/bert-base-japanese-char](https://huggingface.co/cl-tohoku/bert-base-japanese-char) | 95.65 | 94.98 | 94.88 | 94.89 | |
| 140 | +| | | | | | |
| 141 | +| [studio-ousia/luke-japanese-base-lite](https://huggingface.co/studio-ousia/luke-japanese-base-lite) | 96.88 | 96.53 | 96.47 | 96.48 | |
| 142 | +| | | | | | |
| 143 | +| [bert-base-multilingual-cased](https://huggingface.co/bert-base-multilingual-cased) | 96.20 | 95.62 | 95.63 | 95.59 | |
| 144 | +| [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | 96.20 | 95.65 | 95.60 | 95.61 | |
| 145 | +| [studio-ousia/mluke-base-lite](https://huggingface.co/studio-ousia/mluke-base-lite) | 96.47 | 95.82 | 95.94 | 95.86 | |
| 146 | + |
| 147 | +まず、baseサイズのモデルの結果について観察すると、今回の実験では東北大BERTのバージョン2 (bert-base-japanese-v2)が最も高い性能になったことがわかります。 |
| 148 | +Accuracyが97.15、F1が96.64と、かなり高い割合で正しく分類することができていると思います。 |
| 149 | +東北大が公開しているモデルのbert-base-japanese-whole-word-maskingと比較して、bert-base-japanese-v2の方が性能が高く、東北大BERTの中だと、今後は最初にbert-base-japanese-v2を使って問題なさそうだという印象です。 |
| 150 | + |
| 151 | +次点はStudio Ousiaの日本語LUKEで、こちらも非常に高い割合で正しく分類を行えていると思います。 |
| 152 | + |
| 153 | +文字ベースのモデル(cl-tohoku/bert-base-japanese-char-v2など)は、他のモデルと比較して若干性能が低いですが、十分高い性能であるといえると思います。 |
| 154 | + |
| 155 | +多言語モデルの中では、Studio OusiaのmLUKEが最も高い性能になりました。 |
| 156 | + |
| 157 | + |
| 158 | +| large models | Accuracy | Precision | Recall | F1 | |
| 159 | +| ----------------------------------------------------------------------------------------------------- | --------- | --------- | --------- | --------- | |
| 160 | +| [cl-tohoku/bert-large-japanese](https://huggingface.co/cl-tohoku/bert-large-japanese) | **97.69** | **97.50** | 96.84 | **97.10** | |
| 161 | +| [studio-ousia/luke-japanese-large-lite](https://huggingface.co/studio-ousia/luke-japanese-large-lite) | 97.55 | 97.38 | 96.85 | 97.06 | |
| 162 | +| | | | | | |
| 163 | +| [xlm-roberta-large](https://huggingface.co/xlm-roberta-large) | 97.15 | 96.73 | 96.71 | 96.70 | |
| 164 | +| [studio-ousia/mluke-large-lite](https://huggingface.co/studio-ousia/mluke-large-lite) | 97.42 | 97.25 | **96.97** | 97.08 | |
| 165 | + |
| 166 | +次に、largeサイズのモデルの結果について観察すると、AccuracyとF1では東北大BERT (large)が最も高い性能になりましたが、Studio Ousiaの日本語LUKEや多言語LUKEと比較して、ほとんど同じ性能になりました。 |
| 167 | +全体として、baseサイズのモデルよりも高い性能となっており、モデルサイズを増大させることによる性能向上が観察できました。 |
| 168 | + |
| 169 | +## 参考文献 |
| 170 | + |
| 171 | +- [【実装解説】日本語版BERTでlivedoorニュース分類:Google Colaboratoryで(PyTorch)](https://qiita.com/sugulu_Ogawa_ISID/items/697bd03499c1de9cf082) |
| 172 | +- [Livedoorニュースコーパスを文書分類にすぐ使えるように整形する](https://radiology-nlp.hatenablog.com/entry/2019/11/25/124219) |
| 173 | +- https://github.com/yoheikikuta/bert-japanese/blob/master/notebook/finetune-to-livedoor-corpus.ipynb |
| 174 | +- https://github.com/sonoisa/t5-japanese/blob/main/t5_japanese_classification.ipynb |
| 175 | + |
| 176 | +## 引用 |
| 177 | + |
| 178 | +作者: [Hayato Tsukagoshi](https://hpprc.dev) \ |
| 179 | +email: [research.tsukagoshi.hayato@gmail.com](mailto:research.tsukagoshi.hayato@gmail.com) |
| 180 | + |
| 181 | + |
| 182 | +```bibtex |
| 183 | +@misc{ |
| 184 | + hayato-tsukagoshi-2023-bert-classification-tutorial, |
| 185 | + title = {{BERT Classification Tutorial}}, |
| 186 | + author = {Hayato Tsukagoshi}, |
| 187 | + year = {2023}, |
| 188 | + publisher = {GitHub}, |
| 189 | + journal = {GitHub repository}, |
| 190 | + howpublished = {\url{https://github.com/hppRC/bert-classification-tutorial}}, |
| 191 | + url = {https://github.com/hppRC/bert-classification-tutorial}, |
| 192 | +} |
| 193 | +``` |
0 commit comments