Skip to content

Commit

Permalink
Updated Named Entity Recognition using Transformers example for Keras…
Browse files Browse the repository at this point in the history
… 3 (keras-team#1817)

* Updated the ner keras 3 example

* generated files are added

* date format corrected
  • Loading branch information
sitamgithub-MSIT authored Apr 5, 2024
1 parent 18c6544 commit a78d832
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 55 deletions.
64 changes: 31 additions & 33 deletions examples/nlp/ipynb/ner_transformers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
"# Named Entity Recognition using Transformers\n",
"\n",
"**Author:** [Varun Singh](https://www.linkedin.com/in/varunsingh2/)<br>\n",
"**Date created:** Jun 23, 2021<br>\n",
"**Last modified:** Jun 24, 2021<br>\n",
"**Date created:** 2021/06/23<br>\n",
"**Last modified:** 2024/04/05<br>\n",
"**Description:** NER using the Transformers and data from CoNLL 2003 shared task."
]
},
Expand Down Expand Up @@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -69,8 +69,8 @@
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import os\n",
"import keras\n",
"from keras import ops\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from keras import layers\n",
Expand All @@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -123,8 +123,7 @@
" out1 = self.layernorm1(inputs + attn_output)\n",
" ffn_output = self.ffn(out1)\n",
" ffn_output = self.dropout2(ffn_output, training=training)\n",
" return self.layernorm2(out1 + ffn_output)\n",
""
" return self.layernorm2(out1 + ffn_output)\n"
]
},
{
Expand All @@ -138,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -154,12 +153,11 @@
" self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
"\n",
" def call(self, inputs):\n",
" maxlen = tf.shape(inputs)[-1]\n",
" positions = tf.range(start=0, limit=maxlen, delta=1)\n",
" maxlen = ops.shape(inputs)[-1]\n",
" positions = ops.arange(start=0, stop=maxlen, step=1)\n",
" position_embeddings = self.pos_emb(positions)\n",
" token_embeddings = self.token_emb(inputs)\n",
" return token_embeddings + position_embeddings\n",
""
" return token_embeddings + position_embeddings\n"
]
},
{
Expand All @@ -173,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -199,8 +197,7 @@
" x = self.ff(x)\n",
" x = self.dropout2(x, training=training)\n",
" x = self.ff_final(x)\n",
" return x\n",
""
" return x\n"
]
},
{
Expand All @@ -214,7 +211,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -235,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -281,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -313,7 +310,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -348,7 +345,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -370,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -390,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -440,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -456,9 +453,9 @@
" from_logits=False, reduction=None\n",
" )\n",
" loss = loss_fn(y_true, y_pred)\n",
" mask = tf.cast((y_true > 0), dtype=tf.float32)\n",
" mask = ops.cast((y_true > 0), dtype=\"float32\")\n",
" loss = loss * mask\n",
" return tf.reduce_sum(loss) / tf.reduce_sum(mask)\n",
" return ops.sum(loss) / ops.sum(mask)\n",
"\n",
"\n",
"loss = CustomNonPaddingTokenLoss()"
Expand All @@ -475,12 +472,13 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"tf.config.run_functions_eagerly(True)\n",
"ner_model.compile(optimizer=\"adam\", loss=loss)\n",
"ner_model.fit(train_dataset, epochs=10)\n",
"\n",
Expand All @@ -494,7 +492,7 @@
"sample_input = tokenize_and_convert_to_ids(\n",
" \"eu rejects german call to boycott british lamb\"\n",
")\n",
"sample_input = tf.reshape(sample_input, shape=[1, -1])\n",
"sample_input = ops.reshape(sample_input, shape=[1, -1])\n",
"print(sample_input)\n",
"\n",
"output = ner_model.predict(sample_input)\n",
Expand All @@ -519,7 +517,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -531,10 +529,10 @@
"\n",
" for x, y in dataset:\n",
" output = ner_model.predict(x, verbose=0)\n",
" predictions = np.argmax(output, axis=-1)\n",
" predictions = np.reshape(predictions, [-1])\n",
" predictions = ops.argmax(output, axis=-1)\n",
" predictions = ops.reshape(predictions, [-1])\n",
"\n",
" true_tag_ids = np.reshape(y, [-1])\n",
" true_tag_ids = ops.reshape(y, [-1])\n",
"\n",
" mask = (true_tag_ids > 0) & (predictions > 0)\n",
" true_tag_ids = true_tag_ids[mask]\n",
Expand Down Expand Up @@ -603,4 +601,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
23 changes: 12 additions & 11 deletions examples/nlp/md/ner_transformers.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Named Entity Recognition using Transformers

**Author:** [Varun Singh](https://www.linkedin.com/in/varunsingh2/)<br>
**Date created:** Jun 23, 2021<br>
**Last modified:** Jun 24, 2021<br>
**Date created:** 2021/06/23<br>
**Last modified:** 2024/04/05<br>
**Description:** NER using the Transformers and data from CoNLL 2003 shared task.


Expand Down Expand Up @@ -65,8 +65,8 @@ import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import os
import keras
from keras import ops
import numpy as np
import tensorflow as tf
from keras import layers
Expand Down Expand Up @@ -124,8 +124,8 @@ class TokenAndPositionEmbedding(layers.Layer):
self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

def call(self, inputs):
maxlen = tf.shape(inputs)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
maxlen = ops.shape(inputs)[-1]
positions = ops.arange(start=0, stop=maxlen, step=1)
position_embeddings = self.pos_emb(positions)
token_embeddings = self.token_emb(inputs)
return token_embeddings + position_embeddings
Expand Down Expand Up @@ -330,9 +330,9 @@ class CustomNonPaddingTokenLoss(keras.losses.Loss):
from_logits=False, reduction=None
)
loss = loss_fn(y_true, y_pred)
mask = tf.cast((y_true > 0), dtype=tf.float32)
mask = ops.cast((y_true > 0), dtype="float32")
loss = loss * mask
return tf.reduce_sum(loss) / tf.reduce_sum(mask)
return ops.sum(loss) / ops.sum(mask)


loss = CustomNonPaddingTokenLoss()
Expand All @@ -343,6 +343,7 @@ loss = CustomNonPaddingTokenLoss()


```python
tf.config.run_functions_eagerly(True)
ner_model.compile(optimizer="adam", loss=loss)
ner_model.fit(train_dataset, epochs=10)

Expand All @@ -356,7 +357,7 @@ def tokenize_and_convert_to_ids(text):
sample_input = tokenize_and_convert_to_ids(
"eu rejects german call to boycott british lamb"
)
sample_input = tf.reshape(sample_input, shape=[1, -1])
sample_input = ops.reshape(sample_input, shape=[1, -1])
print(sample_input)

output = ner_model.predict(sample_input)
Expand Down Expand Up @@ -409,10 +410,10 @@ def calculate_metrics(dataset):

for x, y in dataset:
output = ner_model.predict(x, verbose=0)
predictions = np.argmax(output, axis=-1)
predictions = np.reshape(predictions, [-1])
predictions = ops.argmax(output, axis=-1)
predictions = ops.reshape(predictions, [-1])

true_tag_ids = np.reshape(y, [-1])
true_tag_ids = ops.reshape(y, [-1])

mask = (true_tag_ids > 0) & (predictions > 0)
true_tag_ids = true_tag_ids[mask]
Expand Down
24 changes: 13 additions & 11 deletions examples/nlp/ner_transformers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
Title: Named Entity Recognition using Transformers
Author: [Varun Singh](https://www.linkedin.com/in/varunsingh2/)
Date created: Jun 23, 2021
Last modified: Jun 24, 2021
Date created: 2021/06/23
Last modified: 2024/04/05
Description: NER using the Transformers and data from CoNLL 2003 shared task.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
Expand Down Expand Up @@ -37,8 +38,8 @@

os.environ["KERAS_BACKEND"] = "tensorflow"

import os
import keras
from keras import ops
import numpy as np
import tensorflow as tf
from keras import layers
Expand Down Expand Up @@ -94,8 +95,8 @@ def __init__(self, maxlen, vocab_size, embed_dim):
self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

def call(self, inputs):
maxlen = tf.shape(inputs)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
maxlen = ops.shape(inputs)[-1]
positions = ops.arange(start=0, stop=maxlen, step=1)
position_embeddings = self.pos_emb(positions)
token_embeddings = self.token_emb(inputs)
return token_embeddings + position_embeddings
Expand Down Expand Up @@ -270,9 +271,9 @@ def call(self, y_true, y_pred):
from_logits=False, reduction=None
)
loss = loss_fn(y_true, y_pred)
mask = tf.cast((y_true > 0), dtype=tf.float32)
mask = ops.cast((y_true > 0), dtype="float32")
loss = loss * mask
return tf.reduce_sum(loss) / tf.reduce_sum(mask)
return ops.sum(loss) / ops.sum(mask)


loss = CustomNonPaddingTokenLoss()
Expand All @@ -281,6 +282,7 @@ def call(self, y_true, y_pred):
## Compile and fit the model
"""

tf.config.run_functions_eagerly(True)
ner_model.compile(optimizer="adam", loss=loss)
ner_model.fit(train_dataset, epochs=10)

Expand All @@ -294,7 +296,7 @@ def tokenize_and_convert_to_ids(text):
sample_input = tokenize_and_convert_to_ids(
"eu rejects german call to boycott british lamb"
)
sample_input = tf.reshape(sample_input, shape=[1, -1])
sample_input = ops.reshape(sample_input, shape=[1, -1])
print(sample_input)

output = ner_model.predict(sample_input)
Expand All @@ -317,10 +319,10 @@ def calculate_metrics(dataset):

for x, y in dataset:
output = ner_model.predict(x, verbose=0)
predictions = np.argmax(output, axis=-1)
predictions = np.reshape(predictions, [-1])
predictions = ops.argmax(output, axis=-1)
predictions = ops.reshape(predictions, [-1])

true_tag_ids = np.reshape(y, [-1])
true_tag_ids = ops.reshape(y, [-1])

mask = (true_tag_ids > 0) & (predictions > 0)
true_tag_ids = true_tag_ids[mask]
Expand Down

0 comments on commit a78d832

Please sign in to comment.