Skip to content

Commit

Permalink
added gpt-2 chatbot
Browse files Browse the repository at this point in the history
  • Loading branch information
huseinzol05 committed Mar 13, 2019
1 parent 06826dd commit f7240a9
Show file tree
Hide file tree
Showing 5 changed files with 964 additions and 19 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
</p>
<p align="center">
<a href="https://github.com/huseinzol05/NLP-Models-Tensorflow/blob/master/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
<a href="#"><img src="https://img.shields.io/badge/total%20models-227--models-blue.svg"></a>
<a href="#"><img src="https://img.shields.io/badge/total%20models-228--models-blue.svg"></a>
</p>

---

**NLP-Models-Tensorflow**, Gathers machine learning and tensorflow deep learning models for NLP problems, with **code simplify**.
**NLP-Models-Tensorflow**, Gathers machine learning and tensorflow deep learning models for NLP problems, **code simplify inside Jupyter Notebooks 100%**.

## Table of contents
* [Text classification](https://github.com/huseinzol05/NLP-Models-Tensorflow#text-classification)
Expand Down Expand Up @@ -137,8 +137,10 @@ Original implementations are quite complex and not really beginner friendly. So
8. Capsule layers + LSTM Seq2Seq-API + Luong Attention + Beam Decoder
9. End-to-End Memory Network
10. Attention is All you need
11. Transformer-XL + LSTM
12. GPT-2 + LSTM

<details><summary>Complete list (48 notebooks)</summary>
<details><summary>Complete list (49 notebooks)</summary>

1. Basic cell Seq2Seq-manual
2. LSTM Seq2Seq-manual
Expand Down Expand Up @@ -188,6 +190,7 @@ Original implementations are quite complex and not really beginner friendly. So
46. Transformer-XL
47. Attention is all you need + Beam Search
48. Transformer-XL + LSTM
49. GPT-2 + LSTM

</details>

Expand Down
164 changes: 148 additions & 16 deletions chatbot/44.memory-network-lstm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,6 @@
" vocab_proj = tf.layers.Dense(vocab_size_to)\n",
" state_proj = tf.layers.Dense(size_layer)\n",
" init_state = state_proj(tf.layers.flatten(answer))\n",
" print(init_state)\n",
" \n",
" helper = tf.contrib.seq2seq.TrainingHelper(\n",
" inputs = tf.nn.embedding_lookup(embedding, shift_right(self.Y)),\n",
Expand Down Expand Up @@ -415,15 +414,7 @@
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(\"dense/BiasAdd:0\", shape=(?, 256), dtype=float32)\n"
]
}
],
"outputs": [],
"source": [
"epoch = 20\n",
"batch_size = 16\n",
Expand All @@ -437,9 +428,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 1, avg loss: 2.862724, avg accuracy: 0.639000\n",
"epoch: 2, avg loss: 1.839667, avg accuracy: 0.734333\n",
"epoch: 3, avg loss: 1.732843, avg accuracy: 0.739833\n",
"epoch: 4, avg loss: 1.661211, avg accuracy: 0.747333\n",
"epoch: 5, avg loss: 1.586425, avg accuracy: 0.751333\n",
"epoch: 6, avg loss: 1.504489, avg accuracy: 0.753833\n",
"epoch: 7, avg loss: 1.416241, avg accuracy: 0.760167\n",
"epoch: 8, avg loss: 1.324014, avg accuracy: 0.769167\n",
"epoch: 9, avg loss: 1.231901, avg accuracy: 0.776333\n",
"epoch: 10, avg loss: 1.143001, avg accuracy: 0.782833\n",
"epoch: 11, avg loss: 1.061561, avg accuracy: 0.790667\n",
"epoch: 12, avg loss: 0.983756, avg accuracy: 0.804000\n",
"epoch: 13, avg loss: 0.901873, avg accuracy: 0.824167\n",
"epoch: 14, avg loss: 0.830711, avg accuracy: 0.846500\n",
"epoch: 15, avg loss: 0.773178, avg accuracy: 0.862167\n",
"epoch: 16, avg loss: 0.721180, avg accuracy: 0.869667\n",
"epoch: 17, avg loss: 0.670212, avg accuracy: 0.881000\n",
"epoch: 18, avg loss: 0.623576, avg accuracy: 0.889167\n",
"epoch: 19, avg loss: 0.589300, avg accuracy: 0.894833\n",
"epoch: 20, avg loss: 0.554041, avg accuracy: 0.900000\n"
]
}
],
"source": [
"for i in range(epoch):\n",
" total_loss, total_accuracy = 0, 0\n",
Expand All @@ -460,9 +478,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"row 1\n",
"QUESTION: i am a werewolf\n",
"REAL ANSWER: a werewolf\n",
"PREDICTED ANSWER: i am afraid not \n",
"\n",
"row 2\n",
"QUESTION: i was dreaming again\n",
"REAL ANSWER: i would think so\n",
"PREDICTED ANSWER: i am so ashamed \n",
"\n",
"row 3\n",
"QUESTION: the kitchen\n",
"REAL ANSWER: very nice\n",
"PREDICTED ANSWER: that is right \n",
"\n",
"row 4\n",
"QUESTION: the bedroom\n",
"REAL ANSWER: there is only one bed\n",
"PREDICTED ANSWER: thank you \n",
"\n"
]
}
],
"source": [
"for i in range(len(batch_x)):\n",
" print('row %d'%(i+1))\n",
Expand All @@ -473,9 +518,96 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"row 1\n",
"QUESTION: but david\n",
"REAL ANSWER: is here that\n",
"PREDICTED ANSWER: what is this \n",
"\n",
"row 2\n",
"QUESTION: hopeless it is hopeless\n",
"REAL ANSWER: tell ballet then back\n",
"PREDICTED ANSWER: i am afraid not \n",
"\n",
"row 3\n",
"QUESTION: miss price\n",
"REAL ANSWER: yes learning\n",
"PREDICTED ANSWER: yes doctor \n",
"\n",
"row 4\n",
"QUESTION: mr kessler wake up please\n",
"REAL ANSWER: is here are\n",
"PREDICTED ANSWER: no way \n",
"\n",
"row 5\n",
"QUESTION: there were witnesses\n",
"REAL ANSWER: why she out\n",
"PREDICTED ANSWER: that is right \n",
"\n",
"row 6\n",
"QUESTION: what about it\n",
"REAL ANSWER: not you are\n",
"PREDICTED ANSWER: i am afraid not \n",
"\n",
"row 7\n",
"QUESTION: go on ask them\n",
"REAL ANSWER: i just home\n",
"PREDICTED ANSWER: goodbye darling \n",
"\n",
"row 8\n",
"QUESTION: beware the moon\n",
"REAL ANSWER: seen hi is he\n",
"PREDICTED ANSWER: that is right \n",
"\n",
"row 9\n",
"QUESTION: did you hear that\n",
"REAL ANSWER: is down what\n",
"PREDICTED ANSWER: it is killing me \n",
"\n",
"row 10\n",
"QUESTION: i heard that\n",
"REAL ANSWER: it here not\n",
"PREDICTED ANSWER: it is okay \n",
"\n",
"row 11\n",
"QUESTION: the hound of the baskervilles\n",
"REAL ANSWER: heard\n",
"PREDICTED ANSWER: yes i remember \n",
"\n",
"row 12\n",
"QUESTION: it is moving\n",
"REAL ANSWER: not you hear\n",
"PREDICTED ANSWER: it is okay \n",
"\n",
"row 13\n",
"QUESTION: nice doggie good boy\n",
"REAL ANSWER: bill stupid\n",
"PREDICTED ANSWER: yes doctor \n",
"\n",
"row 14\n",
"QUESTION: it sounds far away\n",
"REAL ANSWER: that pecos baby seen hi\n",
"PREDICTED ANSWER: what is dead \n",
"\n",
"row 15\n",
"QUESTION: debbie klein cried a lot\n",
"REAL ANSWER: is will srai not\n",
"PREDICTED ANSWER: i am afraid not \n",
"\n",
"row 16\n",
"QUESTION: what are you doing here\n",
"REAL ANSWER: is know look i\n",
"PREDICTED ANSWER: that is right \n",
"\n"
]
}
],
"source": [
"batch_x, seq_x = pad_sentence_batch(X_test[:batch_size], PAD, maxlen_question)\n",
"batch_y, seq_y = pad_sentence_batch(Y_test[:batch_size], PAD, maxlen_answer)\n",
Expand Down
Loading

0 comments on commit f7240a9

Please sign in to comment.