Update "Text classification with switch transformer" to use Keras 3 (#…
Browse files Browse the repository at this point in the history
…1678)

* update example to keras 3

* fix training error

* add .md and .ipynb files
divyashreepathihalli authored Dec 9, 2023
1 parent 92665e4 commit 7646708
Showing 3 changed files with 100 additions and 114 deletions.
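
The gist of the diff: every TensorFlow-specific call in the example is replaced by its backend-agnostic `keras.ops` (or `keras.random`) equivalent, and dtypes are passed as plain strings. The first set of hunks below comes from the notebook's JSON source; the second file is the generated markdown version of the same example. As a minimal, hypothetical sketch of the substitution pattern (standard Keras 3 APIs, not lines taken from this commit):

```python
import keras
from keras import ops

# Rough mapping applied throughout the example:
#   tf.shape(x)             -> ops.shape(x)
#   tf.range(0, n, 1)       -> ops.arange(0, n, 1)
#   tf.reduce_mean(x, 0)    -> ops.mean(x, axis=0)
#   tf.reduce_sum(x, -1)    -> ops.sum(x, axis=-1)
#   tf.cast(x, tf.float32)  -> ops.cast(x, "float32")
#   tf.one_hot(i, depth=n)  -> ops.one_hot(i, n)
#   tf.random.uniform(...)  -> keras.random.uniform(...)

x = ops.arange(0, 5, 1)                                # backend-agnostic range
x = ops.cast(x, "float32")                             # dtypes become plain strings
mask = ops.one_hot(ops.convert_to_tensor([1, 3]), 5)   # no `depth=` keyword
print(ops.mean(mask, axis=0))                          # reductions drop the `reduce_` prefix
```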
@@ -56,9 +56,9 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
"import keras\n",
"from keras import ops\n",
"from keras import layers"
]
},
{
@@ -145,8 +145,8 @@
" self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
"\n",
" def call(self, x):\n",
" maxlen = tf.shape(x)[-1]\n",
" positions = tf.range(start=0, limit=maxlen, delta=1)\n",
" maxlen = ops.shape(x)[-1]\n",
" positions = ops.arange(start=0, stop=maxlen, step=1)\n",
" positions = self.pos_emb(positions)\n",
" x = self.token_emb(x)\n",
" return x + positions\n",
@@ -205,19 +205,17 @@
" # each expert per token. expert_mask [tokens_per_batch, num_experts] contains\n",
" # the expert with the highest router probability in one\u2212hot format.\n",
"\n",
" num_experts = tf.shape(expert_mask)[-1]\n",
" num_experts = ops.shape(expert_mask)[-1]\n",
" # Get the fraction of tokens routed to each expert.\n",
" # density is a vector of length num experts that sums to 1.\n",
" density = tf.reduce_mean(expert_mask, axis=0)\n",
" density = ops.mean(expert_mask, axis=0)\n",
" # Get fraction of probability mass assigned to each expert from the router\n",
" # across all tokens. density_proxy is a vector of length num experts that sums to 1.\n",
" density_proxy = tf.reduce_mean(router_probs, axis=0)\n",
" density_proxy = ops.mean(router_probs, axis=0)\n",
" # Want both vectors to have uniform allocation (1/num experts) across all\n",
" # num_expert elements. The two vectors will be pushed towards uniform allocation\n",
" # when the dot product is minimized.\n",
" loss = tf.reduce_mean(density_proxy * density) * tf.cast(\n",
" (num_experts**2), tf.dtypes.float32\n",
" )\n",
" loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), \"float32\")\n",
" return loss\n",
""
]
@@ -254,47 +254,45 @@
"\n",
" if training:\n",
" # Add noise for exploration across experts.\n",
" router_logits += tf.random.uniform(\n",
" router_logits += keras.random.uniform(\n",
" shape=router_logits.shape, minval=0.9, maxval=1.1\n",
" )\n",
" # Probabilities for each token of what expert it should be sent to.\n",
" router_probs = keras.activations.softmax(router_logits, axis=-1)\n",
" # Get the top\u22121 expert for each token. expert_gate is the top\u22121 probability\n",
" # from the router for each token. expert_index is what expert each token\n",
" # is going to be routed to.\n",
" expert_gate, expert_index = tf.math.top_k(router_probs, k=1)\n",
" expert_gate, expert_index = ops.top_k(router_probs, k=1)\n",
" # expert_mask shape: [tokens_per_batch, num_experts]\n",
" expert_mask = tf.one_hot(expert_index, depth=self.num_experts)\n",
" expert_mask = ops.one_hot(expert_index, self.num_experts)\n",
" # Compute load balancing loss.\n",
" aux_loss = load_balanced_loss(router_probs, expert_mask)\n",
" self.add_loss(aux_loss)\n",
" # Experts have a fixed capacity, ensure we do not exceed it. Construct\n",
" # the batch indices, to each expert, with position in expert make sure that\n",
" # not more that expert capacity examples can be routed to each expert.\n",
" position_in_expert = tf.cast(\n",
" tf.math.cumsum(expert_mask, axis=0) * expert_mask, tf.dtypes.int32\n",
" position_in_expert = ops.cast(\n",
" ops.cumsum(expert_mask, axis=0) * expert_mask, \"int32\"\n",
" )\n",
" # Keep only tokens that fit within expert capacity.\n",
" expert_mask *= tf.cast(\n",
" tf.math.less(\n",
" tf.cast(position_in_expert, tf.dtypes.int32), self.expert_capacity\n",
" ),\n",
" tf.dtypes.float32,\n",
" expert_mask *= ops.cast(\n",
" ops.less(ops.cast(position_in_expert, \"int32\"), self.expert_capacity),\n",
" \"float32\",\n",
" )\n",
" expert_mask_flat = tf.reduce_sum(expert_mask, axis=-1)\n",
" expert_mask_flat = ops.sum(expert_mask, axis=-1)\n",
" # Mask out the experts that have overflowed the expert capacity.\n",
" expert_gate *= expert_mask_flat\n",
" # Combine expert outputs and scaling with router probability.\n",
" # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" combined_tensor = tf.expand_dims(\n",
" combined_tensor = ops.expand_dims(\n",
" expert_gate\n",
" * expert_mask_flat\n",
" * tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),\n",
" * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),\n",
" -1,\n",
" ) * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)\n",
" ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)\n",
" # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]\n",
" # that is 1 if the token gets routed to the corresponding expert.\n",
" dispatch_tensor = tf.cast(combined_tensor, tf.dtypes.float32)\n",
" dispatch_tensor = ops.cast(combined_tensor, \"float32\")\n",
"\n",
" return dispatch_tensor, combined_tensor\n",
""
@@ -333,33 +329,33 @@
" super().__init__()\n",
"\n",
" def call(self, inputs):\n",
" batch_size = tf.shape(inputs)[0]\n",
" num_tokens_per_example = tf.shape(inputs)[1]\n",
" batch_size = ops.shape(inputs)[0]\n",
" num_tokens_per_example = ops.shape(inputs)[1]\n",
"\n",
" # inputs shape: [num_tokens_per_batch, embed_dim]\n",
" inputs = tf.reshape(inputs, [num_tokens_per_batch, self.embed_dim])\n",
" inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])\n",
" # dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]\n",
" # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" dispatch_tensor, combine_tensor = self.router(inputs)\n",
" # expert_inputs shape: [num_experts, expert_capacity, embed_dim]\n",
" expert_inputs = tf.einsum(\"ab,acd->cdb\", inputs, dispatch_tensor)\n",
" expert_inputs = tf.reshape(\n",
" expert_inputs = ops.einsum(\"ab,acd->cdb\", inputs, dispatch_tensor)\n",
" expert_inputs = ops.reshape(\n",
" expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]\n",
" )\n",
" # Dispatch to experts\n",
" expert_input_list = tf.unstack(expert_inputs, axis=0)\n",
" expert_input_list = ops.unstack(expert_inputs, axis=0)\n",
" expert_output_list = [\n",
" self.experts[idx](expert_input)\n",
" for idx, expert_input in enumerate(expert_input_list)\n",
" ]\n",
" # expert_outputs shape: [expert_capacity, num_experts, embed_dim]\n",
" expert_outputs = tf.stack(expert_output_list, axis=1)\n",
" expert_outputs = ops.stack(expert_output_list, axis=1)\n",
" # expert_outputs_combined shape: [tokens_per_batch, embed_dim]\n",
" expert_outputs_combined = tf.einsum(\n",
" expert_outputs_combined = ops.einsum(\n",
" \"abc,xba->xc\", expert_outputs, combine_tensor\n",
" )\n",
" # output shape: [batch_size, num_tokens_per_example, embed_dim]\n",
" outputs = tf.reshape(\n",
" outputs = ops.reshape(\n",
" expert_outputs_combined,\n",
" [batch_size, num_tokens_per_example, self.embed_dim],\n",
" )\n",
@@ -397,7 +393,7 @@
" self.dropout1 = layers.Dropout(dropout_rate)\n",
" self.dropout2 = layers.Dropout(dropout_rate)\n",
"\n",
" def call(self, inputs, training):\n",
" def call(self, inputs, training=False):\n",
" attn_output = self.att(inputs, inputs)\n",
" attn_output = self.dropout1(attn_output, training=training)\n",
" out1 = self.layernorm1(inputs + attn_output)\n",
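The final visible notebook change, giving `training` a default value in `TransformerBlock.call`, is presumably what the "fix training error" bullet in the commit message refers to: under Keras 3 a layer can be invoked without an explicit `training` argument, and a required positional parameter then raises a TypeError. A minimal sketch of the pattern (illustrative toy layer, not code from the example):

```python
import keras
from keras import layers, ops


class Block(layers.Layer):
    """Toy layer showing why `training` gets a default value."""

    def __init__(self):
        super().__init__()
        self.dropout = layers.Dropout(0.1)

    def call(self, inputs, training=False):
        # With a default, the layer works whether or not the caller
        # (or the framework) passes `training` explicitly.
        return self.dropout(inputs, training=training)


block = Block()
print(block(ops.ones((2, 4))).shape)  # (2, 4); called without an explicit `training`
```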
78 changes: 36 additions & 42 deletions examples/nlp/md/text_classification_with_switch_transformer.md
@@ -34,9 +34,9 @@ model for demonstration purposes.


```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import ops
from keras import layers
```

---
@@ -55,8 +55,6 @@ x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)

<div class="k-default-codeblock">
```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17464789/17464789 [==============================] - 1s 0us/step
25000 Training sequences
25000 Validation sequences
@@ -102,8 +100,8 @@ class TokenAndPositionEmbedding(layers.Layer):
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

def call(self, x):
maxlen = tf.shape(x)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
maxlen = ops.shape(x)[-1]
positions = ops.arange(start=0, stop=maxlen, step=1)
positions = self.pos_emb(positions)
x = self.token_emb(x)
return x + positions
@@ -138,19 +136,17 @@ def load_balanced_loss(router_probs, expert_mask):
# each expert per token. expert_mask [tokens_per_batch, num_experts] contains
# the expert with the highest router probability in one−hot format.

num_experts = tf.shape(expert_mask)[-1]
num_experts = ops.shape(expert_mask)[-1]
# Get the fraction of tokens routed to each expert.
# density is a vector of length num experts that sums to 1.
density = tf.reduce_mean(expert_mask, axis=0)
density = ops.mean(expert_mask, axis=0)
# Get fraction of probability mass assigned to each expert from the router
# across all tokens. density_proxy is a vector of length num experts that sums to 1.
density_proxy = tf.reduce_mean(router_probs, axis=0)
density_proxy = ops.mean(router_probs, axis=0)
# Want both vectors to have uniform allocation (1/num experts) across all
# num_expert elements. The two vectors will be pushed towards uniform allocation
# when the dot product is minimized.
loss = tf.reduce_mean(density_proxy * density) * tf.cast(
(num_experts**2), tf.dtypes.float32
)
loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), "float32")
return loss

```
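
As a quick, hypothetical sanity check of the load-balancing formula above (toy numbers, not part of the example): with 2 experts and perfectly uniform routing the loss evaluates to 1.0, its minimum; if every token were routed to a single expert with matching router probabilities, it would be 2.0.

```python
import keras
from keras import ops

# 4 tokens, 2 experts, perfectly balanced routing (toy values).
router_probs = ops.convert_to_tensor([[0.5, 0.5]] * 4, dtype="float32")
expert_mask = ops.convert_to_tensor(
    [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
)

density = ops.mean(expert_mask, axis=0)         # [0.5, 0.5]
density_proxy = ops.mean(router_probs, axis=0)  # [0.5, 0.5]
loss = ops.mean(density_proxy * density) * ops.cast(2**2, "float32")
print(loss)  # 1.0 -- the minimum, reached at uniform allocation
```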
@@ -174,47 +170,45 @@ class Router(layers.Layer):

if training:
# Add noise for exploration across experts.
router_logits += tf.random.uniform(
router_logits += keras.random.uniform(
shape=router_logits.shape, minval=0.9, maxval=1.1
)
# Probabilities for each token of what expert it should be sent to.
router_probs = keras.activations.softmax(router_logits, axis=-1)
# Get the top−1 expert for each token. expert_gate is the top−1 probability
# from the router for each token. expert_index is what expert each token
# is going to be routed to.
expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
expert_gate, expert_index = ops.top_k(router_probs, k=1)
# expert_mask shape: [tokens_per_batch, num_experts]
expert_mask = tf.one_hot(expert_index, depth=self.num_experts)
expert_mask = ops.one_hot(expert_index, self.num_experts)
# Compute load balancing loss.
aux_loss = load_balanced_loss(router_probs, expert_mask)
self.add_loss(aux_loss)
# Experts have a fixed capacity, ensure we do not exceed it. Construct
# the batch indices, to each expert, with position in expert make sure that
# not more that expert capacity examples can be routed to each expert.
position_in_expert = tf.cast(
tf.math.cumsum(expert_mask, axis=0) * expert_mask, tf.dtypes.int32
position_in_expert = ops.cast(
ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
)
# Keep only tokens that fit within expert capacity.
expert_mask *= tf.cast(
tf.math.less(
tf.cast(position_in_expert, tf.dtypes.int32), self.expert_capacity
),
tf.dtypes.float32,
expert_mask *= ops.cast(
ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
"float32",
)
expert_mask_flat = tf.reduce_sum(expert_mask, axis=-1)
expert_mask_flat = ops.sum(expert_mask, axis=-1)
# Mask out the experts that have overflowed the expert capacity.
expert_gate *= expert_mask_flat
# Combine expert outputs and scaling with router probability.
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
combined_tensor = tf.expand_dims(
combined_tensor = ops.expand_dims(
expert_gate
* expert_mask_flat
* tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
-1,
) * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
# Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
# that is 1 if the token gets routed to the corresponding expert.
dispatch_tensor = tf.cast(combined_tensor, tf.dtypes.float32)
dispatch_tensor = ops.cast(combined_tensor, "float32")

return dispatch_tensor, combined_tensor

@@ -240,33 +234,33 @@ class Switch(layers.Layer):
super().__init__()

def call(self, inputs):
batch_size = tf.shape(inputs)[0]
num_tokens_per_example = tf.shape(inputs)[1]
batch_size = ops.shape(inputs)[0]
num_tokens_per_example = ops.shape(inputs)[1]

# inputs shape: [num_tokens_per_batch, embed_dim]
inputs = tf.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
# dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
expert_inputs = tf.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = tf.reshape(
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = ops.reshape(
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
)
# Dispatch to experts
expert_input_list = tf.unstack(expert_inputs, axis=0)
expert_input_list = ops.unstack(expert_inputs, axis=0)
expert_output_list = [
self.experts[idx](expert_input)
for idx, expert_input in enumerate(expert_input_list)
]
# expert_outputs shape: [expert_capacity, num_experts, embed_dim]
expert_outputs = tf.stack(expert_output_list, axis=1)
expert_outputs = ops.stack(expert_output_list, axis=1)
# expert_outputs_combined shape: [tokens_per_batch, embed_dim]
expert_outputs_combined = tf.einsum(
expert_outputs_combined = ops.einsum(
"abc,xba->xc", expert_outputs, combine_tensor
)
# output shape: [batch_size, num_tokens_per_example, embed_dim]
outputs = tf.reshape(
outputs = ops.reshape(
expert_outputs_combined,
[batch_size, num_tokens_per_example, self.embed_dim],
)
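
To trace the einsum bookkeeping in `Switch.call` above, here is a hypothetical shape-only walk-through with made-up sizes and constant tensors (no real routing involved):

```python
import keras
from keras import ops

# Made-up sizes, used only to check the einsum output shapes.
tokens_per_batch, num_experts, expert_capacity, embed_dim = 8, 2, 4, 3

inputs = ops.ones((tokens_per_batch, embed_dim))
dispatch = ops.ones((tokens_per_batch, num_experts, expert_capacity))
combine = dispatch  # same shape as dispatch_tensor in the real layer

expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch)
print(ops.shape(expert_inputs))  # (2, 4, 3): [num_experts, expert_capacity, embed_dim]

expert_outputs = ops.ones((expert_capacity, num_experts, embed_dim))
combined = ops.einsum("abc,xba->xc", expert_outputs, combine)
print(ops.shape(combined))       # (8, 3): [tokens_per_batch, embed_dim]
```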
Expand All @@ -292,7 +286,7 @@ class TransformerBlock(layers.Layer):
self.dropout1 = layers.Dropout(dropout_rate)
self.dropout2 = layers.Dropout(dropout_rate)

def call(self, inputs, training):
def call(self, inputs, training=False):
attn_output = self.att(inputs, inputs)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(inputs + attn_output)
@@ -363,13 +357,13 @@ run_experiment(classifier)
<div class="k-default-codeblock">
```
Epoch 1/3
500/500 [==============================] - 645s 1s/step - loss: 1.4064 - accuracy: 0.8070 - val_loss: 1.3201 - val_accuracy: 0.8642
500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 - val_accuracy: 0.8748 - val_loss: 1.2891
Epoch 2/3
500/500 [==============================] - 625s 1s/step - loss: 1.2073 - accuracy: 0.9218 - val_loss: 1.3140 - val_accuracy: 0.8713
500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 - val_accuracy: 0.8752 - val_loss: 1.3090
Epoch 3/3
500/500 [==============================] - 637s 1s/step - loss: 1.1428 - accuracy: 0.9494 - val_loss: 1.3530 - val_accuracy: 0.8618
500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 - val_accuracy: 0.8614 - val_loss: 1.3744
<keras.src.callbacks.History at 0x136fb5450>
<keras.src.callbacks.history.History at 0x7efb79d82a90>
```
</div>