Update "Text classification with switch transformer" to use Keras 3 #1678

Merged
@@ -56,9 +56,9 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
"import keras\n",
"from keras import ops\n",
"from keras import layers"
]
},
{
@@ -145,8 +145,8 @@
" self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
"\n",
" def call(self, x):\n",
" maxlen = tf.shape(x)[-1]\n",
" positions = tf.range(start=0, limit=maxlen, delta=1)\n",
" maxlen = ops.shape(x)[-1]\n",
" positions = ops.arange(start=0, stop=maxlen, step=1)\n",
" positions = self.pos_emb(positions)\n",
" x = self.token_emb(x)\n",
" return x + positions\n",
@@ -205,19 +205,17 @@
" # each expert per token. expert_mask [tokens_per_batch, num_experts] contains\n",
" # the expert with the highest router probability in one\u2212hot format.\n",
"\n",
" num_experts = tf.shape(expert_mask)[-1]\n",
" num_experts = ops.shape(expert_mask)[-1]\n",
" # Get the fraction of tokens routed to each expert.\n",
" # density is a vector of length num experts that sums to 1.\n",
" density = tf.reduce_mean(expert_mask, axis=0)\n",
" density = ops.mean(expert_mask, axis=0)\n",
" # Get fraction of probability mass assigned to each expert from the router\n",
" # across all tokens. density_proxy is a vector of length num experts that sums to 1.\n",
" density_proxy = tf.reduce_mean(router_probs, axis=0)\n",
" density_proxy = ops.mean(router_probs, axis=0)\n",
" # Want both vectors to have uniform allocation (1/num experts) across all\n",
" # num_expert elements. The two vectors will be pushed towards uniform allocation\n",
" # when the dot product is minimized.\n",
" loss = tf.reduce_mean(density_proxy * density) * tf.cast(\n",
" (num_experts**2), tf.dtypes.float32\n",
" )\n",
" loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), \"float32\")\n",
" return loss\n",
""
]
@@ -254,47 +252,45 @@
"\n",
" if training:\n",
" # Add noise for exploration across experts.\n",
" router_logits += tf.random.uniform(\n",
" router_logits += keras.random.uniform(\n",
" shape=router_logits.shape, minval=0.9, maxval=1.1\n",
" )\n",
" # Probabilities for each token of what expert it should be sent to.\n",
" router_probs = keras.activations.softmax(router_logits, axis=-1)\n",
" # Get the top\u22121 expert for each token. expert_gate is the top\u22121 probability\n",
" # from the router for each token. expert_index is what expert each token\n",
" # is going to be routed to.\n",
" expert_gate, expert_index = tf.math.top_k(router_probs, k=1)\n",
" expert_gate, expert_index = ops.top_k(router_probs, k=1)\n",
" # expert_mask shape: [tokens_per_batch, num_experts]\n",
" expert_mask = tf.one_hot(expert_index, depth=self.num_experts)\n",
" expert_mask = ops.one_hot(expert_index, self.num_experts)\n",
" # Compute load balancing loss.\n",
" aux_loss = load_balanced_loss(router_probs, expert_mask)\n",
" self.add_loss(aux_loss)\n",
" # Experts have a fixed capacity, ensure we do not exceed it. Construct\n",
" # the batch indices, to each expert, with position in expert make sure that\n",
" # not more that expert capacity examples can be routed to each expert.\n",
" position_in_expert = tf.cast(\n",
" tf.math.cumsum(expert_mask, axis=0) * expert_mask, tf.dtypes.int32\n",
" position_in_expert = ops.cast(\n",
" ops.cumsum(expert_mask, axis=0) * expert_mask, \"int32\"\n",
" )\n",
" # Keep only tokens that fit within expert capacity.\n",
" expert_mask *= tf.cast(\n",
" tf.math.less(\n",
" tf.cast(position_in_expert, tf.dtypes.int32), self.expert_capacity\n",
" ),\n",
" tf.dtypes.float32,\n",
" expert_mask *= ops.cast(\n",
" ops.less(ops.cast(position_in_expert, \"int32\"), self.expert_capacity),\n",
" \"float32\",\n",
" )\n",
" expert_mask_flat = tf.reduce_sum(expert_mask, axis=-1)\n",
" expert_mask_flat = ops.sum(expert_mask, axis=-1)\n",
" # Mask out the experts that have overflowed the expert capacity.\n",
" expert_gate *= expert_mask_flat\n",
" # Combine expert outputs and scaling with router probability.\n",
" # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" combined_tensor = tf.expand_dims(\n",
" combined_tensor = ops.expand_dims(\n",
" expert_gate\n",
" * expert_mask_flat\n",
" * tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),\n",
" * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),\n",
" -1,\n",
" ) * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)\n",
" ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)\n",
" # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]\n",
" # that is 1 if the token gets routed to the corresponding expert.\n",
" dispatch_tensor = tf.cast(combined_tensor, tf.dtypes.float32)\n",
" dispatch_tensor = ops.cast(combined_tensor, \"float32\")\n",
"\n",
" return dispatch_tensor, combined_tensor\n",
""
@@ -333,33 +329,33 @@
" super().__init__()\n",
"\n",
" def call(self, inputs):\n",
" batch_size = tf.shape(inputs)[0]\n",
" num_tokens_per_example = tf.shape(inputs)[1]\n",
" batch_size = ops.shape(inputs)[0]\n",
" num_tokens_per_example = ops.shape(inputs)[1]\n",
"\n",
" # inputs shape: [num_tokens_per_batch, embed_dim]\n",
" inputs = tf.reshape(inputs, [num_tokens_per_batch, self.embed_dim])\n",
" inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])\n",
" # dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]\n",
" # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" dispatch_tensor, combine_tensor = self.router(inputs)\n",
" # expert_inputs shape: [num_experts, expert_capacity, embed_dim]\n",
" expert_inputs = tf.einsum(\"ab,acd->cdb\", inputs, dispatch_tensor)\n",
" expert_inputs = tf.reshape(\n",
" expert_inputs = ops.einsum(\"ab,acd->cdb\", inputs, dispatch_tensor)\n",
" expert_inputs = ops.reshape(\n",
" expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]\n",
" )\n",
" # Dispatch to experts\n",
" expert_input_list = tf.unstack(expert_inputs, axis=0)\n",
" expert_input_list = ops.unstack(expert_inputs, axis=0)\n",
" expert_output_list = [\n",
" self.experts[idx](expert_input)\n",
" for idx, expert_input in enumerate(expert_input_list)\n",
" ]\n",
" # expert_outputs shape: [expert_capacity, num_experts, embed_dim]\n",
" expert_outputs = tf.stack(expert_output_list, axis=1)\n",
" expert_outputs = ops.stack(expert_output_list, axis=1)\n",
" # expert_outputs_combined shape: [tokens_per_batch, embed_dim]\n",
" expert_outputs_combined = tf.einsum(\n",
" expert_outputs_combined = ops.einsum(\n",
" \"abc,xba->xc\", expert_outputs, combine_tensor\n",
" )\n",
" # output shape: [batch_size, num_tokens_per_example, embed_dim]\n",
" outputs = tf.reshape(\n",
" outputs = ops.reshape(\n",
" expert_outputs_combined,\n",
" [batch_size, num_tokens_per_example, self.embed_dim],\n",
" )\n",
@@ -397,7 +393,7 @@
" self.dropout1 = layers.Dropout(dropout_rate)\n",
" self.dropout2 = layers.Dropout(dropout_rate)\n",
"\n",
" def call(self, inputs, training):\n",
" def call(self, inputs, training=False):\n",
" attn_output = self.att(inputs, inputs)\n",
" attn_output = self.dropout1(attn_output, training=training)\n",
" out1 = self.layernorm1(inputs + attn_output)\n",
78 changes: 36 additions & 42 deletions examples/nlp/md/text_classification_with_switch_transformer.md
@@ -34,9 +34,9 @@ model for demonstration purposes.


```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import ops
from keras import layers
```
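
The import swap above is the heart of the migration: `keras.ops` exposes a NumPy-style, backend-agnostic API, so the rest of the example no longer depends on TensorFlow directly. As a quick illustration (a hypothetical smoke test, not part of the diff), the same ops run on whichever backend `KERAS_BACKEND` selects:

```python
# Hypothetical smoke test, not from this PR: keras.ops mirrors NumPy and
# dispatches to the configured backend (TensorFlow, JAX, or PyTorch).
import keras
from keras import ops

x = ops.arange(start=0, stop=5, step=1)  # replaces tf.range
x = ops.cast(x, "float32")               # replaces tf.cast(..., tf.dtypes.float32)
print(ops.mean(x))                       # replaces tf.reduce_mean
```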

---
@@ -55,8 +55,6 @@ x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)

<div class="k-default-codeblock">
```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17464789/17464789 [==============================] - 1s 0us/step
25000 Training sequences
25000 Validation sequences

@@ -102,8 +100,8 @@ class TokenAndPositionEmbedding(layers.Layer):
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

def call(self, x):
maxlen = tf.shape(x)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
maxlen = ops.shape(x)[-1]
positions = ops.arange(start=0, stop=maxlen, step=1)
positions = self.pos_emb(positions)
x = self.token_emb(x)
return x + positions
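
`ops.arange(start, stop, step)` is a drop-in replacement for `tf.range(start, limit, delta)` here; only the keyword names change. A minimal standalone check (toy sizes assumed, not from the diff):

```python
# Hypothetical check, not from this PR: positional indices built with
# ops.arange feed the position embedding exactly as tf.range used to.
import keras
from keras import layers, ops

maxlen, embed_dim = 6, 8
pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
positions = ops.arange(start=0, stop=maxlen, step=1)  # [0, 1, 2, 3, 4, 5]
print(ops.shape(pos_emb(positions)))  # (6, 8)
```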
@@ -138,19 +136,17 @@ def load_balanced_loss(router_probs, expert_mask):
# each expert per token. expert_mask [tokens_per_batch, num_experts] contains
# the expert with the highest router probability in one−hot format.

num_experts = tf.shape(expert_mask)[-1]
num_experts = ops.shape(expert_mask)[-1]
# Get the fraction of tokens routed to each expert.
# density is a vector of length num experts that sums to 1.
density = tf.reduce_mean(expert_mask, axis=0)
density = ops.mean(expert_mask, axis=0)
# Get fraction of probability mass assigned to each expert from the router
# across all tokens. density_proxy is a vector of length num experts that sums to 1.
density_proxy = tf.reduce_mean(router_probs, axis=0)
density_proxy = ops.mean(router_probs, axis=0)
# Want both vectors to have uniform allocation (1/num experts) across all
# num_expert elements. The two vectors will be pushed towards uniform allocation
# when the dot product is minimized.
loss = tf.reduce_mean(density_proxy * density) * tf.cast(
(num_experts**2), tf.dtypes.float32
)
loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), "float32")
return loss

```
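
A useful sanity check on this loss: when routing is perfectly uniform, `density` and `density_proxy` are both `1/num_experts` everywhere, so the loss is `mean(1/E * 1/E) * E**2 = 1.0` for any number of experts. A hypothetical verification (not part of the diff):

```python
# Hypothetical sanity check, not from this PR: uniform routing gives a
# load-balancing loss of exactly 1.0, independent of num_experts.
import keras
from keras import ops

num_experts = 4
density = ops.full((num_experts,), 1.0 / num_experts)
density_proxy = ops.full((num_experts,), 1.0 / num_experts)
loss = ops.mean(density_proxy * density) * ops.cast(num_experts**2, "float32")
print(float(loss))  # 1.0
```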
@@ -174,47 +170,45 @@ class Router(layers.Layer):

if training:
# Add noise for exploration across experts.
router_logits += tf.random.uniform(
router_logits += keras.random.uniform(
shape=router_logits.shape, minval=0.9, maxval=1.1
)
# Probabilities for each token of what expert it should be sent to.
router_probs = keras.activations.softmax(router_logits, axis=-1)
# Get the top−1 expert for each token. expert_gate is the top−1 probability
# from the router for each token. expert_index is what expert each token
# is going to be routed to.
expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
expert_gate, expert_index = ops.top_k(router_probs, k=1)
# expert_mask shape: [tokens_per_batch, num_experts]
expert_mask = tf.one_hot(expert_index, depth=self.num_experts)
expert_mask = ops.one_hot(expert_index, self.num_experts)
# Compute load balancing loss.
aux_loss = load_balanced_loss(router_probs, expert_mask)
self.add_loss(aux_loss)
        # Experts have a fixed capacity; ensure we do not exceed it. Construct
        # the batch indices, to each expert, with position in expert, making sure that
        # no more than expert_capacity examples can be routed to each expert.
position_in_expert = tf.cast(
tf.math.cumsum(expert_mask, axis=0) * expert_mask, tf.dtypes.int32
position_in_expert = ops.cast(
ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
)
# Keep only tokens that fit within expert capacity.
expert_mask *= tf.cast(
tf.math.less(
tf.cast(position_in_expert, tf.dtypes.int32), self.expert_capacity
),
tf.dtypes.float32,
expert_mask *= ops.cast(
ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
"float32",
)
expert_mask_flat = tf.reduce_sum(expert_mask, axis=-1)
expert_mask_flat = ops.sum(expert_mask, axis=-1)
# Mask out the experts that have overflowed the expert capacity.
expert_gate *= expert_mask_flat
# Combine expert outputs and scaling with router probability.
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
combined_tensor = tf.expand_dims(
combined_tensor = ops.expand_dims(
expert_gate
* expert_mask_flat
* tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
-1,
) * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
# Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
# that is 1 if the token gets routed to the corresponding expert.
dispatch_tensor = tf.cast(combined_tensor, tf.dtypes.float32)
dispatch_tensor = ops.cast(combined_tensor, "float32")

return dispatch_tensor, combined_tensor
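
The least obvious step above is the capacity bookkeeping: `ops.cumsum(expert_mask, axis=0) * expert_mask` gives every token a 1-based position within the expert it was routed to, and tokens whose position fails the `ops.less` comparison against `expert_capacity` are masked out. A small illustration with assumed values (not from the diff):

```python
# Hypothetical illustration, not from this PR: the cumsum trick assigns each
# token a 1-based position within the expert it was routed to.
import keras
from keras import ops

# 4 tokens, 2 experts; tokens 0, 1, 3 go to expert 0 and token 2 to expert 1.
expert_mask = ops.convert_to_tensor(
    [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
)
position_in_expert = ops.cast(ops.cumsum(expert_mask, axis=0) * expert_mask, "int32")
print(position_in_expert)
# [[1 0]
#  [2 0]
#  [0 1]
#  [3 0]]
# The ops.less(position_in_expert, expert_capacity) mask then drops every
# token whose position does not satisfy position < expert_capacity.
```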

@@ -240,33 +234,33 @@ class Switch(layers.Layer):
super().__init__()

def call(self, inputs):
batch_size = tf.shape(inputs)[0]
num_tokens_per_example = tf.shape(inputs)[1]
batch_size = ops.shape(inputs)[0]
num_tokens_per_example = ops.shape(inputs)[1]

# inputs shape: [num_tokens_per_batch, embed_dim]
inputs = tf.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
        # dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
expert_inputs = tf.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = tf.reshape(
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = ops.reshape(
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
)
# Dispatch to experts
expert_input_list = tf.unstack(expert_inputs, axis=0)
expert_input_list = ops.unstack(expert_inputs, axis=0)
expert_output_list = [
self.experts[idx](expert_input)
for idx, expert_input in enumerate(expert_input_list)
]
# expert_outputs shape: [expert_capacity, num_experts, embed_dim]
expert_outputs = tf.stack(expert_output_list, axis=1)
expert_outputs = ops.stack(expert_output_list, axis=1)
# expert_outputs_combined shape: [tokens_per_batch, embed_dim]
expert_outputs_combined = tf.einsum(
expert_outputs_combined = ops.einsum(
"abc,xba->xc", expert_outputs, combine_tensor
)
# output shape: [batch_size, num_tokens_per_example, embed_dim]
outputs = tf.reshape(
outputs = ops.reshape(
expert_outputs_combined,
[batch_size, num_tokens_per_example, self.embed_dim],
)
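
The two `ops.einsum` calls do the actual dispatch and combine. Reading the subscripts of the first one: `inputs` is `ab` = [tokens_per_batch, embed_dim], `dispatch_tensor` is `acd` = [tokens_per_batch, num_experts, expert_capacity], and the result `cdb` = [num_experts, expert_capacity, embed_dim] stacks each expert's slice of tokens. A shape-only walk-through with assumed toy dimensions (not from the diff):

```python
# Hypothetical shape walk-through, not from this PR: the dispatch einsum
# scatters tokens into per-expert buffers of fixed capacity.
import keras
from keras import ops

tokens_per_batch, embed_dim = 8, 16
num_experts, expert_capacity = 2, 4
inputs = keras.random.normal((tokens_per_batch, embed_dim))
# An all-ones tensor stands in for the binary dispatch tensor; only the
# shapes matter for this check.
dispatch_tensor = ops.ones((tokens_per_batch, num_experts, expert_capacity))
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
print(ops.shape(expert_inputs))  # (2, 4, 16)
```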
@@ -292,7 +286,7 @@ class TransformerBlock(layers.Layer):
self.dropout1 = layers.Dropout(dropout_rate)
self.dropout2 = layers.Dropout(dropout_rate)

def call(self, inputs, training):
def call(self, inputs, training=False):
attn_output = self.att(inputs, inputs)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(inputs + attn_output)
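
The only change to `TransformerBlock` is the `training=False` default, presumably because Keras 3 does not always inject `training` when a layer is called directly (for example while the model graph is being built), so a positional argument without a default can raise a missing-argument error. A minimal sketch with a toy layer standing in for `TransformerBlock` (hypothetical, not from the diff):

```python
# Hypothetical sketch, not from this PR: a stand-in layer showing the
# training=False default in action under Keras 3.
import keras
from keras import layers, ops

class DropoutBlock(layers.Layer):
    def __init__(self, rate=0.1):
        super().__init__()
        self.dropout = layers.Dropout(rate)

    def call(self, inputs, training=False):
        # The default lets block(x) work even when Keras does not pass
        # `training` explicitly.
        return self.dropout(inputs, training=training)

block = DropoutBlock()
print(ops.shape(block(keras.random.normal((2, 4)))))  # (2, 4)
```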
@@ -363,13 +357,13 @@ run_experiment(classifier)
<div class="k-default-codeblock">
```
Epoch 1/3
500/500 [==============================] - 645s 1s/step - loss: 1.4064 - accuracy: 0.8070 - val_loss: 1.3201 - val_accuracy: 0.8642
500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 - val_accuracy: 0.8748 - val_loss: 1.2891
Epoch 2/3
500/500 [==============================] - 625s 1s/step - loss: 1.2073 - accuracy: 0.9218 - val_loss: 1.3140 - val_accuracy: 0.8713
500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 - val_accuracy: 0.8752 - val_loss: 1.3090
Epoch 3/3
500/500 [==============================] - 637s 1s/step - loss: 1.1428 - accuracy: 0.9494 - val_loss: 1.3530 - val_accuracy: 0.8618
500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 - val_accuracy: 0.8614 - val_loss: 1.3744

<keras.src.callbacks.History at 0x136fb5450>
<keras.src.callbacks.history.History at 0x7efb79d82a90>

```
</div>