From d1470147355f227a5c7784253a56ac8f6302cbb0 Mon Sep 17 00:00:00 2001
From: husein zolkepli
Date: Sun, 17 Nov 2019 10:19:32 +0800
Subject: [PATCH] added albert base text classification
---
README.md | 6 +-
.../77.transfer-learning-albert-base.ipynb | 775 ++++++++++++++++++
2 files changed, 779 insertions(+), 2 deletions(-)
create mode 100644 text-classification/77.transfer-learning-albert-base.ipynb
diff --git a/README.md b/README.md
index 92bdac3..efdd396 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
-
+
---
@@ -389,8 +389,9 @@ Trained on [English sentiment dataset](text-classification/data), accuracy table
10. BERT
11. Dynamic Memory Network
12. XL-net
+13. ALBERT
-Complete list (76 notebooks)
+Complete list (77 notebooks)
1. Basic cell RNN
2. Basic cell RNN + Hinge
@@ -468,6 +469,7 @@ Trained on [English sentiment dataset](text-classification/data), accuracy table
74. Transfer learning BERT Base drop 6 layers
75. Transfer learning BERT Large drop 12 layers
76. Transfer learning XL-net Base
+77. Transfer learning ALBERT
diff --git a/text-classification/77.transfer-learning-albert-base.ipynb b/text-classification/77.transfer-learning-albert-base.ipynb
new file mode 100644
index 0000000..ec4a54c
--- /dev/null
+++ b/text-classification/77.transfer-learning-albert-base.ipynb
@@ -0,0 +1,775 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !wget https://storage.googleapis.com/tfhub-modules/google/albert_base/2.tar.gz\n",
+ "# !tar -zxf 2.tar.gz"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !wget https://raw.githubusercontent.com/huseinzol05/NLP-Models-Tensorflow/master/text-classification/data.zip\n",
+ "# !unzip data.zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !wget https://raw.githubusercontent.com/huseinzol05/NLP-Models-Tensorflow/master/text-classification/utils.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip3 install albert-tensorflow"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "30k-clean.model 30k-clean.vocab albert_config.json\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!ls assets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "variables.data-00000-of-00001 variables.index\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!ls variables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
+ "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
+ " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/lamb_optimizer.py:34: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from albert import modeling\n",
+ "from albert import optimization\n",
+ "from albert import tokenization\n",
+ "import tensorflow as tf\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
+ "\n",
+ "INFO:tensorflow:loading sentence piece model\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = tokenization.FullTokenizer(\n",
+ " vocab_file='assets/30k-clean.vocab', do_lower_case=True,\n",
+ " spm_model_file='assets/30k-clean.model')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['▁hus', 'ein', '▁is', '▁so', '▁cute']"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.tokenize('husein is so cute')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from utils import *\n",
+ "from sklearn.model_selection import train_test_split"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['negative', 'positive']\n",
+ "10662\n",
+ "10662\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainset = sklearn.datasets.load_files(container_path = 'data', encoding = 'UTF-8')\n",
+ "trainset.data, trainset.target = separate_dataset(trainset,1.0)\n",
+ "print (trainset.target_names)\n",
+ "print (len(trainset.data))\n",
+ "print (len(trainset.target))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MAX_SEQ_LENGTH = 100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10662/10662 [00:00<00:00, 11143.98it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm import tqdm\n",
+ "\n",
+ "input_ids, input_masks, segment_ids = [], [], []\n",
+ "\n",
+ "for text in tqdm(trainset.data):\n",
+ " tokens_a = tokenizer.tokenize(text.lower())\n",
+ " if len(tokens_a) > MAX_SEQ_LENGTH - 2:\n",
+ " tokens_a = tokens_a[:(MAX_SEQ_LENGTH - 2)]\n",
+ " tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
+ " segment_id = [0] * len(tokens)\n",
+ " input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
+ " input_mask = [1] * len(input_id)\n",
+ " padding = [0] * (MAX_SEQ_LENGTH - len(input_id))\n",
+ " input_id += padding\n",
+ " input_mask += padding\n",
+ " segment_id += padding\n",
+ " \n",
+ " input_ids.append(input_id)\n",
+ " input_masks.append(input_mask)\n",
+ " segment_ids.append(segment_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:116: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "albert_config = modeling.AlbertConfig.from_json_file('assets/albert_config.json')\n",
+ "albert_config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['▁moving', '▁uneven', '▁success']"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.tokenize(trainset.data[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "epoch = 10\n",
+ "batch_size = 32\n",
+ "warmup_proportion = 0.1\n",
+ "num_train_steps = int(len(input_ids) / batch_size * epoch)\n",
+ "num_warmup_steps = int(num_train_steps * warmup_proportion)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Model:\n",
+ " def __init__(\n",
+ " self,\n",
+ " dimension_output,\n",
+ " learning_rate = 2e-5,\n",
+ " ):\n",
+ " self.X = tf.placeholder(tf.int32, [None, None])\n",
+ " self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
+ " self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
+ " self.Y = tf.placeholder(tf.int32, [None])\n",
+ " \n",
+ " model = modeling.AlbertModel(\n",
+ " config=albert_config,\n",
+ " is_training=False,\n",
+ " input_ids=self.X,\n",
+ " input_mask=self.input_masks,\n",
+ " token_type_ids=self.segment_ids,\n",
+ " use_one_hot_embeddings=False)\n",
+ " \n",
+ " output_layer = model.get_pooled_output()\n",
+ " self.logits = tf.layers.dense(output_layer, dimension_output)\n",
+ " \n",
+ " self.cost = tf.reduce_mean(\n",
+ " tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits = self.logits, labels = self.Y\n",
+ " )\n",
+ " )\n",
+ " \n",
+ " self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
+ " num_train_steps, num_warmup_steps, False)\n",
+ " \n",
+ " correct_pred = tf.equal(\n",
+ " tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
+ " )\n",
+ " self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:194: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:507: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:588: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:1025: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:253: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use keras.layers.dense instead.\n",
+ "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n",
+ "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
+ "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n",
+ "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:36: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:41: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py:409: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Deprecated in favor of operator or tf.math.divide.\n",
+ "INFO:tensorflow:++++++ warmup starts at step 0, for 333 steps ++++++\n",
+ "INFO:tensorflow:using adamw\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:101: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
+ "\n",
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
+ ]
+ }
+ ],
+ "source": [
+ "dimension_output = 2\n",
+ "learning_rate = 5e-5\n",
+ "\n",
+ "tf.reset_default_graph()\n",
+ "sess = tf.InteractiveSession()\n",
+ "model = Model(\n",
+ " dimension_output,\n",
+ " learning_rate\n",
+ ")\n",
+ "\n",
+ "sess.run(tf.global_variables_initializer())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
+ "saver = tf.train.Saver(var_list = var_lists)\n",
+ "var_lists"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use standard file APIs to check for files with this prefix.\n",
+ "INFO:tensorflow:Restoring parameters from variables/variables\n"
+ ]
+ }
+ ],
+ "source": [
+ "saver.restore(sess, 'variables/variables')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_input_ids, test_input_ids, train_input_masks, test_input_masks, train_segment_ids, test_segment_ids, train_Y, test_Y = train_test_split(\n",
+ " input_ids, input_masks, segment_ids, trainset.target, test_size = 0.2\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:59<00:00, 4.46it/s, accuracy=0.765, cost=0.51] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 12.26it/s, accuracy=0.667, cost=0.721]\n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 0, pass acc: 0.000000, current acc: 0.737615\n",
+ "time taken: 65.3192994594574\n",
+ "epoch: 0, training loss: 0.588583, training acc: 0.693689, valid loss: 0.545611, valid acc: 0.737615\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.70it/s, accuracy=0.647, cost=0.495]\n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.28it/s, accuracy=0.714, cost=0.53] \n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 1, pass acc: 0.737615, current acc: 0.777711\n",
+ "time taken: 61.85024404525757\n",
+ "epoch: 1, training loss: 0.557109, training acc: 0.742843, valid loss: 0.498011, valid acc: 0.777711\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.71it/s, accuracy=0.941, cost=0.268]\n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.27it/s, accuracy=0.619, cost=0.792]\n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 2, pass acc: 0.777711, current acc: 0.778157\n",
+ "time taken: 61.79904627799988\n",
+ "epoch: 2, training loss: 0.437586, training acc: 0.820040, valid loss: 0.557257, valid acc: 0.778157\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.71it/s, accuracy=0.882, cost=0.416] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.31it/s, accuracy=0.714, cost=0.909]\n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 3, pass acc: 0.778157, current acc: 0.796933\n",
+ "time taken: 61.780051469802856\n",
+ "epoch: 3, training loss: 0.340006, training acc: 0.870001, valid loss: 0.635498, valid acc: 0.796933\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.70it/s, accuracy=0.941, cost=0.24] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.28it/s, accuracy=0.714, cost=0.876]\n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time taken: 61.891632318496704\n",
+ "epoch: 4, training loss: 0.279899, training acc: 0.908796, valid loss: 0.670878, valid acc: 0.795526\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.70it/s, accuracy=0.941, cost=0.249] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.26it/s, accuracy=0.762, cost=1.02] \n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 5, pass acc: 0.796933, current acc: 0.799053\n",
+ "time taken: 61.81794476509094\n",
+ "epoch: 5, training loss: 0.227987, training acc: 0.935997, valid loss: 0.876410, valid acc: 0.799053\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.70it/s, accuracy=0.941, cost=0.236] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.27it/s, accuracy=0.714, cost=1.08] \n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time taken: 61.82991981506348\n",
+ "epoch: 6, training loss: 0.200658, training acc: 0.948308, valid loss: 0.842970, valid acc: 0.795057\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.71it/s, accuracy=0.941, cost=0.259] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.24it/s, accuracy=0.714, cost=1.17] \n",
+ "train minibatch loop: 0%| | 0/267 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time taken: 61.805060386657715\n",
+ "epoch: 7, training loss: 0.169402, training acc: 0.958860, valid loss: 0.827865, valid acc: 0.791307\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "train minibatch loop: 100%|██████████| 267/267 [00:56<00:00, 4.71it/s, accuracy=0.941, cost=0.169] \n",
+ "test minibatch loop: 100%|██████████| 67/67 [00:05<00:00, 13.27it/s, accuracy=0.762, cost=1.04] "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time taken: 61.79657506942749\n",
+ "epoch: 8, training loss: 0.127402, training acc: 0.971992, valid loss: 0.854281, valid acc: 0.790615\n",
+ "\n",
+ "break epoch:9\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm import tqdm\n",
+ "import time\n",
+ "\n",
+ "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
+ "\n",
+ "while True:\n",
+ " lasttime = time.time()\n",
+ " if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
+ " print('break epoch:%d\\n' % (EPOCH))\n",
+ " break\n",
+ "\n",
+ " train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
+ " pbar = tqdm(\n",
+ " range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
+ " )\n",
+ " for i in pbar:\n",
+ " index = min(i + batch_size, len(train_input_ids))\n",
+ " batch_x = train_input_ids[i: index]\n",
+ " batch_masks = train_input_masks[i: index]\n",
+ " batch_segment = train_segment_ids[i: index]\n",
+ " batch_y = train_Y[i: index]\n",
+ " acc, cost, _ = sess.run(\n",
+ " [model.accuracy, model.cost, model.optimizer],\n",
+ " feed_dict = {\n",
+ " model.Y: batch_y,\n",
+ " model.X: batch_x,\n",
+ " model.segment_ids: batch_segment,\n",
+ " model.input_masks: batch_masks\n",
+ " },\n",
+ " )\n",
+ " assert not np.isnan(cost)\n",
+ " train_loss += cost\n",
+ " train_acc += acc\n",
+ " pbar.set_postfix(cost = cost, accuracy = acc)\n",
+ "\n",
+ " pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
+ " for i in pbar:\n",
+ " index = min(i + batch_size, len(test_input_ids))\n",
+ " batch_x = test_input_ids[i: index]\n",
+ " batch_masks = test_input_masks[i: index]\n",
+ " batch_segment = test_segment_ids[i: index]\n",
+ " batch_y = test_Y[i: index]\n",
+ " acc, cost = sess.run(\n",
+ " [model.accuracy, model.cost],\n",
+ " feed_dict = {\n",
+ " model.Y: batch_y,\n",
+ " model.X: batch_x,\n",
+ " model.segment_ids: batch_segment,\n",
+ " model.input_masks: batch_masks\n",
+ " },\n",
+ " )\n",
+ " test_loss += cost\n",
+ " test_acc += acc\n",
+ " pbar.set_postfix(cost = cost, accuracy = acc)\n",
+ "\n",
+ " train_loss /= len(train_input_ids) / batch_size\n",
+ " train_acc /= len(train_input_ids) / batch_size\n",
+ " test_loss /= len(test_input_ids) / batch_size\n",
+ " test_acc /= len(test_input_ids) / batch_size\n",
+ "\n",
+ " if test_acc > CURRENT_ACC:\n",
+ " print(\n",
+ " 'epoch: %d, pass acc: %f, current acc: %f'\n",
+ " % (EPOCH, CURRENT_ACC, test_acc)\n",
+ " )\n",
+ " CURRENT_ACC = test_acc\n",
+ " CURRENT_CHECKPOINT = 0\n",
+ " else:\n",
+ " CURRENT_CHECKPOINT += 1\n",
+ " \n",
+ " print('time taken:', time.time() - lasttime)\n",
+ " print(\n",
+ " 'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
+ " % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
+ " )\n",
+ " EPOCH += 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}