Skip to content

Commit

Permalink
added bert extractive summarization
Browse files Browse the repository at this point in the history
  • Loading branch information
huseinzol05 committed Sep 24, 2019
1 parent 77874ab commit 8596bf1
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ Accuracy based on ROUGE-2.
1. LSTM RNN, test accuracy 16.13%
2. Dilated-CNN, test accuracy 15.54%
3. Multihead Attention, test accuracy 26.33%
4. BERT-Base

### [Generator](generator)

Expand Down
269 changes: 269 additions & 0 deletions extractive-summarization/4.bert-base.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the pickled dataset dict; later cells read its 'train_*' / 'test_*' entries.\n",
"# NOTE(review): pickle.load can execute arbitrary code -- only open a file you trust.\n",
"with open('dataset-bert.pkl', 'rb') as dataset_file:\n",
"    dataset = pickle.load(dataset_file)\n",
"dataset.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Files of the pretrained checkpoint; all three live in the same directory.\n",
"BERT_VOCAB        = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
"BERT_INIT_CHKPNT  = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
"BERT_CONFIG       = 'uncased_L-12_H-768_A-12/bert_config.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import bert\n",
"from bert import run_classifier\n",
"from bert import optimization\n",
"from bert import tokenization\n",
"from bert import modeling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check the lower-casing flag against the checkpoint path before building a\n",
"# tokenizer that uses the same flag.\n",
"tokenization.validate_case_matches_checkpoint(\n",
"    do_lower_case=True, init_checkpoint=BERT_INIT_CHKPNT)\n",
"tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Encoder configuration, parsed from the JSON file at BERT_CONFIG.\n",
"bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epoch = 20\n",
"batch_size = 32\n",
"warmup_proportion = 0.1\n",
"# Ceil division: the training loop steps through range(0, len, batch_size), so\n",
"# the final, smaller batch of every epoch is an optimizer step too. Truncating\n",
"# here would end the learning-rate schedule before training actually stops.\n",
"steps_per_epoch = (len(dataset['train_texts']) + batch_size - 1) // batch_size\n",
"num_train_steps = steps_per_epoch * epoch\n",
"num_warmup_steps = int(num_train_steps * warmup_proportion)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model:\n",
"    \"\"\"BERT encoder with a one-unit scoring head (TF1 graph mode).\n",
"\n",
"    Placeholders (all 2-D, batch-major):\n",
"\n",
"    * ``X``, ``segment_ids``, ``input_masks`` -- token ids, segment ids and\n",
"      attention mask handed to ``modeling.BertModel``.\n",
"    * ``clss`` -- token positions whose encoder vectors are scored\n",
"      (presumably one position per candidate sentence -- confirm against\n",
"      the preprocessing that produced the dataset).\n",
"    * ``Y`` -- float labels, one per ``clss`` position; compared against 1\n",
"      for the metric below.\n",
"    * ``mask`` -- 1 for real ``clss`` positions, 0 for padding.\n",
"\n",
"    Attributes exposed to callers: ``logits``, ``cost``, ``optimizer``,\n",
"    ``accuracy``.\n",
"\n",
"    Reads the notebook-level names ``bert_config``, ``num_train_steps`` and\n",
"    ``num_warmup_steps``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        learning_rate = 2e-5,\n",
"    ):\n",
"        \"\"\"Build placeholders, encoder, loss, train op and metric.\n",
"\n",
"        Args:\n",
"            learning_rate: learning rate passed to\n",
"                ``optimization.create_optimizer``.\n",
"        \"\"\"\n",
"        self.X = tf.placeholder(tf.int32, [None, None])\n",
"        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
"        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
"        self.Y = tf.placeholder(tf.float32, [None, None])\n",
"        self.mask = tf.placeholder(tf.int32, [None, None])\n",
"        self.clss = tf.placeholder(tf.int32, [None, None])\n",
"        mask = tf.cast(self.mask, tf.float32)\n",
"\n",
"        # NOTE(review): is_training=True is fixed in the graph, so the same\n",
"        # (training-mode) encoder is used when this model is evaluated.\n",
"        model = modeling.BertModel(\n",
"            config=bert_config,\n",
"            is_training=True,\n",
"            input_ids=self.X,\n",
"            input_mask=self.input_masks,\n",
"            token_type_ids=self.segment_ids,\n",
"            use_one_hot_embeddings=False)\n",
"\n",
"        # One encoder vector per input token.\n",
"        outputs = model.get_sequence_output()\n",
"        # Keep only the vectors at the `clss` positions so that the logits\n",
"        # line up element-wise with `Y` and `mask`. Without this step the\n",
"        # logits are per-token and cannot be combined with the per-`clss`\n",
"        # mask and labels below.\n",
"        outputs = tf.batch_gather(outputs, self.clss)\n",
"        self.logits = tf.layers.dense(outputs, 1)\n",
"        self.logits = tf.squeeze(self.logits, axis=-1)\n",
"        # Zero the logits of padded positions.\n",
"        self.logits = self.logits * mask\n",
"        crossent = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.Y)\n",
"        # Masked mean: padded positions contribute nothing to the loss.\n",
"        crossent = crossent * mask\n",
"        crossent = tf.reduce_sum(crossent)\n",
"        total_size = tf.reduce_sum(mask)\n",
"        self.cost = tf.div_no_nan(crossent, total_size)\n",
"\n",
"        self.optimizer = optimization.create_optimizer(self.cost, learning_rate,\n",
"                                                       num_train_steps, num_warmup_steps, False)\n",
"\n",
"        # Despite its name this is the mean rounded prediction over positions\n",
"        # labelled 1, i.e. the share of positive labels predicted positive.\n",
"        l = tf.round(tf.sigmoid(self.logits))\n",
"        self.accuracy = tf.reduce_mean(tf.cast(tf.boolean_mask(l, tf.equal(self.Y, 1)), tf.float32))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Clear the default graph before building the model, then open the session\n",
"# and run the initializer for every variable.\n",
"tf.reset_default_graph()\n",
"sess = tf.InteractiveSession()\n",
"model = Model(learning_rate=2e-5)\n",
"init_op = tf.global_variables_initializer()\n",
"sess.run(init_op)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialise every variable, then overwrite the trainable variables under the\n",
"# 'bert' scope with the checkpoint values; variables outside that scope keep\n",
"# their freshly initialised values.\n",
"sess.run(tf.global_variables_initializer())\n",
"var_lists = tf.get_collection(\n",
"    tf.GraphKeys.TRAINABLE_VARIABLES, scope='bert')\n",
"saver = tf.train.Saver(var_list=var_lists)\n",
"saver.restore(sess, BERT_INIT_CHKPNT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pad_sentence_batch(sentence_batch, pad_int):\n",
"    \"\"\"Right-pad every list in `sentence_batch` to the longest one.\n",
"\n",
"    Args:\n",
"        sentence_batch: non-empty sequence of lists.\n",
"        pad_int: value appended to the shorter lists.\n",
"\n",
"    Returns:\n",
"        (padded_seqs, seq_lens): the padded copies and the original lengths,\n",
"        both in input order.\n",
"\n",
"    Raises:\n",
"        ValueError: if `sentence_batch` is empty.\n",
"    \"\"\"\n",
"    seq_lens = [len(sentence) for sentence in sentence_batch]\n",
"    longest = max(seq_lens)\n",
"    padded_seqs = [\n",
"        sentence + [pad_int] * (longest - length)\n",
"        for sentence, length in zip(sentence_batch, seq_lens)\n",
"    ]\n",
"    return padded_seqs, seq_lens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pull the train/test splits out of the dataset dict, one pair per field.\n",
"train_X, test_X = dataset['train_texts'], dataset['test_texts']\n",
"train_clss, test_clss = dataset['train_clss'], dataset['test_clss']\n",
"train_Y, test_Y = dataset['train_labels'], dataset['test_labels']\n",
"train_segments, test_segments = dataset['train_segments'], dataset['test_segments']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"\n",
"def _build_feed(texts, labels, segments, clss, start):\n",
"    \"\"\"Assemble the feed dict for the minibatch that begins at row `start`.\n",
"\n",
"    Rows `start` .. `start + batch_size` (clipped to the data length) are\n",
"    padded to a common length: token ids, labels and segment ids with 0,\n",
"    `clss` positions with -1.\n",
"\n",
"    Args:\n",
"        texts, labels, segments, clss: parallel per-example sequences of lists.\n",
"        start: index of the first example of the batch.\n",
"\n",
"    Returns:\n",
"        dict mapping the model placeholders to the padded batch values.\n",
"    \"\"\"\n",
"    stop = min(start + batch_size, len(texts))\n",
"    batch_x, _ = pad_sentence_batch(texts[start : stop], 0)\n",
"    batch_y, _ = pad_sentence_batch(labels[start : stop], 0)\n",
"    batch_segments, _ = pad_sentence_batch(segments[start : stop], 0)\n",
"    batch_clss, _ = pad_sentence_batch(clss[start : stop], -1)\n",
"    batch_clss = np.array(batch_clss)\n",
"    batch_x = np.array(batch_x)\n",
"    # The mask must be derived while the -1 padding is still present ...\n",
"    batch_mask = 1 - (batch_clss == -1)\n",
"    # ... only then are the padded positions replaced by 0.\n",
"    batch_clss[batch_clss == -1] = 0\n",
"    # Attention mask: 1 wherever the token id is not the padding id 0.\n",
"    mask_src = 1 - (batch_x == 0)\n",
"    return {model.X: batch_x,\n",
"            model.Y: batch_y,\n",
"            model.mask: batch_mask,\n",
"            model.clss: batch_clss,\n",
"            model.segment_ids: batch_segments,\n",
"            model.input_masks: mask_src}\n",
"\n",
"\n",
"for e in range(epoch):\n",
"    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
"\n",
"    # Training pass: one optimizer step per minibatch.\n",
"    pbar = tqdm.tqdm(\n",
"        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
"    for i in pbar:\n",
"        feed = _build_feed(train_X, train_Y, train_segments, train_clss, i)\n",
"        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
"                                     feed_dict = feed)\n",
"        train_loss.append(loss)\n",
"        train_acc.append(accuracy)\n",
"        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
"\n",
"    # Evaluation pass: metrics only, the optimizer op is not run.\n",
"    pbar = tqdm.tqdm(\n",
"        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
"    for i in pbar:\n",
"        feed = _build_feed(test_X, test_Y, test_segments, test_clss, i)\n",
"        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
"                                  feed_dict = feed)\n",
"        # Record the per-batch test metrics; without this the epoch summary\n",
"        # below averages two empty lists and prints nan.\n",
"        test_loss.append(loss)\n",
"        test_acc.append(accuracy)\n",
"        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
"\n",
"    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
"        np.mean(train_loss),np.mean(train_acc)))\n",
"    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
"        np.mean(test_loss),np.mean(test_acc)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 8596bf1

Please sign in to comment.