Update preprocessing layers guide
fchollet committed Jun 20, 2021
1 parent 92d0a93 commit f274756
Showing 3 changed files with 82 additions and 64 deletions.
42 changes: 24 additions & 18 deletions guides/ipynb/preprocessing_layers.ipynb
@@ -94,7 +94,7 @@
"source": [
"## The `adapt()` method\n",
"\n",
"Some preprocessing layers have an internal state that must be computed based on\n",
"Some preprocessing layers have an internal state that can be computed based on\n",
"a sample of the training data. The list of stateful preprocessing layers is:\n",
"\n",
"- `TextVectorization`: holds a mapping between string tokens and integer indices\n",
@@ -104,7 +104,8 @@
"- `Discretization`: holds information about value bucket boundaries.\n",
"\n",
"Crucially, these layers are **non-trainable**. Their state is not set during training; it\n",
"must be set **before training**, a step called \"adaptation\".\n",
"must be set **before training**, either by initializing them from a precomputed constant,\n",
"or by \"adapting\" them on data.\n",
"\n",
"You set the state of a preprocessing layer by exposing it to training data, via the\n",
"`adapt()` method:"
@@ -225,12 +226,19 @@
"batches of preprocessed data, like this:\n",
"\n",
"```python\n",
"dataset = dataset.map(\n",
" lambda x, y: (preprocessing_layer(x), y))\n",
"dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))\n",
"```\n",
"\n",
"With this option, your preprocessing will happen on CPU, asynchronously, and will be\n",
"buffered before going into the model.\n",
"In addition, if you call `dataset.prefetch(tf.data.AUTOTUNE)` on your dataset,\n",
"the preprocessing will happen efficiently in parallel with training:\n",
"\n",
"```python\n",
"dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))\n",
"dataset = dataset.prefetch(tf.data.AUTOTUNE)\n",
"model.fit(dataset, ...)\n",
"```\n",
"\n",
"This is the best option for `TextVectorization`, and all structured data preprocessing\n",
"layers. It can also be a good option if you're training on CPU\n",
@@ -401,7 +409,7 @@
"data = tf.constant([[\"a\"], [\"b\"], [\"c\"], [\"b\"], [\"c\"], [\"a\"]])\n",
"\n",
"# Use StringLookup to build an index of the feature values and encode output.\n",
"lookup = preprocessing.StringLookup(output_mode=\"binary\")\n",
"lookup = preprocessing.StringLookup(output_mode=\"one_hot\")\n",
"lookup.adapt(data)\n",
"\n",
"# Convert new test data (which includes unknown feature values)\n",
@@ -416,10 +424,8 @@
"colab_type": "text"
},
"source": [
"Note that index 0 is reserved for missing values (which you should specify as the empty\n",
"string `\"\"`), and index 1 is reserved for out-of-vocabulary values (values that were not\n",
"seen during `adapt()`). You can configure this by using the `mask_token` and `oov_token`\n",
"constructor arguments of `StringLookup`.\n",
"Note that, here, index 0 is reserved for out-of-vocabulary values\n",
"(values that were not seen during `adapt()`).\n",
"\n",
"You can see the `StringLookup` in action in the\n",
"[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/)\n",
@@ -447,7 +453,7 @@
"data = tf.constant([[10], [20], [20], [10], [30], [0]])\n",
"\n",
"# Use IntegerLookup to build an index of the feature values and encode output.\n",
"lookup = preprocessing.IntegerLookup(output_mode=\"multi_hot\")\n",
"lookup = preprocessing.IntegerLookup(output_mode=\"one_hot\")\n",
"lookup.adapt(data)\n",
"\n",
"# Convert new test data (which includes unknown feature values)\n",
@@ -501,7 +507,7 @@
"# Use the Hashing layer to hash the values to the range [0, 64]\n",
"hasher = preprocessing.Hashing(num_bins=64, salt=1337)\n",
"\n",
"# Use the CategoryEncoding layer to one-hot encode the hashed values\n",
"# Use the CategoryEncoding layer to multi-hot encode the hashed values\n",
"encoder = preprocessing.CategoryEncoding(num_tokens=64, output_mode=\"multi_hot\")\n",
"encoded_data = encoder(hasher(data))\n",
"print(encoded_data.shape)"
@@ -555,11 +561,11 @@
"\n",
"# Create a labeled dataset (which includes unknown tokens)\n",
"train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" ([\"\\nThe Brain is deeper than the sea\"], [1])\n",
" ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
")\n",
"\n",
"# Preprocess the string inputs, turning them into int sequences\n",
"train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))\n",
"train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
"# Train the model on the int sequences\n",
"print(\"\\nTraining model...\")\n",
"model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
@@ -620,7 +626,7 @@
" \"With ease and You beside\",\n",
" ]\n",
")\n",
"# Instantiate TextVectorization with \"binary\" output_mode (multi-hot)\n",
"# Instantiate TextVectorization with \"multi_hot\" output_mode\n",
"# and ngrams=2 (index all bigrams)\n",
"text_vectorizer = preprocessing.TextVectorization(output_mode=\"multi_hot\", ngrams=2)\n",
"# Index the bigrams via `adapt()`\n",
@@ -638,11 +644,11 @@
"\n",
"# Create a labeled dataset (which includes unknown tokens)\n",
"train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" ([\"\\nThe Brain is deeper than the sea\"], [1])\n",
" ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
")\n",
"\n",
"# Preprocess the string inputs, turning them into int sequences\n",
"train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))\n",
"train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
"# Train the model on the int sequences\n",
"print(\"\\nTraining model...\")\n",
"model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
@@ -707,11 +713,11 @@
"\n",
"# Create a labeled dataset (which includes unknown tokens)\n",
"train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" ([\"\\nThe Brain is deeper than the sea\"], [1])\n",
" ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
")\n",
"\n",
"# Preprocess the string inputs, turning them into int sequences\n",
"train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))\n",
"train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
"# Train the model on the int sequences\n",
"print(\"\\nTraining model...\")\n",
"model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
62 changes: 34 additions & 28 deletions guides/md/preprocessing_layers.md
@@ -73,7 +73,7 @@ are only active during training.
---
## The `adapt()` method

Some preprocessing layers have an internal state that must be computed based on
Some preprocessing layers have an internal state that can be computed based on
a sample of the training data. The list of stateful preprocessing layers is:

- `TextVectorization`: holds a mapping between string tokens and integer indices
@@ -83,7 +83,8 @@ indices.
- `Discretization`: holds information about value bucket boundaries.

Crucially, these layers are **non-trainable**. Their state is not set during training; it
must be set **before training**, a step called "adaptation".
must be set **before training**, either by initializing them from a precomputed constant,
or by "adapting" them on data.

You set the state of a preprocessing layer by exposing it to training data, via the
`adapt()` method:
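The guide's own example at this point is not shown in the hunk; a rough sketch of the `adapt()` workflow (using the `Normalization` layer, not necessarily the guide's exact example) would be:

```python
import numpy as np
from tensorflow.keras.layers.experimental import preprocessing

# Toy sample; real code would adapt on a representative slice of the training data.
data = np.array([[0.1, 0.2, 0.3],
                 [0.8, 0.9, 1.0],
                 [1.5, 1.6, 1.7]])

layer = preprocessing.Normalization()
layer.adapt(data)  # computes the per-feature mean and variance from `data`

normalized = layer(data)  # applies (x - mean) / sqrt(variance)
print("mean: %.2f  std: %.2f" % (normalized.numpy().mean(), normalized.numpy().std()))
```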
@@ -195,12 +196,19 @@ all image preprocessing and data augmentation layers.
batches of preprocessed data, like this:

```python
dataset = dataset.map(
lambda x, y: (preprocessing_layer(x), y))
dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))
```

With this option, your preprocessing will happen on CPU, asynchronously, and will be
buffered before going into the model.
In addition, if you call `dataset.prefetch(tf.data.AUTOTUNE)` on your dataset,
the preprocessing will happen efficiently in parallel with training:

```python
dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))
dataset = dataset.prefetch(tf.data.AUTOTUNE)
model.fit(dataset, ...)
```

This is the best option for `TextVectorization`, and all structured data preprocessing
layers. It can also be a good option if you're training on CPU
@@ -285,9 +293,9 @@ model.fit(train_dataset, steps_per_epoch=5)

<div class="k-default-codeblock">
```
5/5 [==============================] - 10s 514ms/step - loss: 9.2905
5/5 [==============================] - 12s 576ms/step - loss: 8.7839
<tensorflow.python.keras.callbacks.History at 0x15e1664d0>
<tensorflow.python.keras.callbacks.History at 0x1550b7110>
```
</div>
@@ -321,9 +329,9 @@ model.fit(x_train, y_train)

<div class="k-default-codeblock">
```
1563/1563 [==============================] - 1s 835us/step - loss: 2.1309
1563/1563 [==============================] - 2s 910us/step - loss: 2.1323
<tensorflow.python.keras.callbacks.History at 0x15fce2710>
<tensorflow.python.keras.callbacks.History at 0x15727b8d0>
```
</div>
@@ -335,7 +343,7 @@ model.fit(x_train, y_train)
data = tf.constant([["a"], ["b"], ["c"], ["b"], ["c"], ["a"]])

# Use StringLookup to build an index of the feature values and encode output.
lookup = preprocessing.StringLookup(output_mode="binary")
lookup = preprocessing.StringLookup(output_mode="one_hot")
lookup.adapt(data)

# Convert new test data (which includes unknown feature values)
@@ -356,10 +364,8 @@ tf.Tensor(
```
</div>
Note that index 0 is reserved for missing values (which you should specify as the empty
string `""`), and index 1 is reserved for out-of-vocabulary values (values that were not
seen during `adapt()`). You can configure this by using the `mask_token` and `oov_token`
constructor arguments of `StringLookup`.
Note that, here, index 0 is reserved for out-of-vocabulary values
(values that were not seen during `adapt()`).
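As a self-contained illustration (not from this commit): a value never seen during `adapt()` is routed to the OOV index, so it activates position 0 of its one-hot vector.

```python
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

lookup = preprocessing.StringLookup(output_mode="one_hot")
lookup.adapt(tf.constant([["a"], ["b"], ["c"]]))

# "z" was not in the adapted data, so it maps to the OOV slot at index 0.
print(lookup(tf.constant([["a"], ["z"]])))
```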

You can see the `StringLookup` in action in the
[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/)
@@ -373,7 +379,7 @@ example.
data = tf.constant([[10], [20], [20], [10], [30], [0]])

# Use IntegerLookup to build an index of the feature values and encode output.
lookup = preprocessing.IntegerLookup(output_mode="multi_hot")
lookup = preprocessing.IntegerLookup(output_mode="one_hot")
lookup.adapt(data)

# Convert new test data (which includes unknown feature values)
@@ -419,7 +425,7 @@ data = np.random.randint(0, 100000, size=(10000, 1))
# Use the Hashing layer to hash the values to the range [0, 64)
hasher = preprocessing.Hashing(num_bins=64, salt=1337)

# Use the CategoryEncoding layer to one-hot encode the hashed values
# Use the CategoryEncoding layer to multi-hot encode the hashed values
encoder = preprocessing.CategoryEncoding(num_tokens=64, output_mode="multi_hot")
encoded_data = encoder(hasher(data))
print(encoded_data.shape)
@@ -466,11 +472,11 @@ model = keras.Model(inputs, outputs)

# Create a labeled dataset (which includes unknown tokens)
train_dataset = tf.data.Dataset.from_tensor_slices(
(["\nThe Brain is deeper than the sea"], [1])
(["The Brain is deeper than the sea", "for if they are held Blue to Blue"], [1, 0])
)

# Preprocess the string inputs, turning them into int sequences
train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))
train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))
# Train the model on the int sequences
print("\nTraining model...")
model.compile(optimizer="rmsprop", loss="mse")
@@ -499,14 +505,14 @@ Encoded text:
<div class="k-default-codeblock">
```
Training model...
1/1 [==============================] - 1s 1s/step - loss: 0.9776
1/1 [==============================] - 2s 2s/step - loss: 0.5292
```
</div>

<div class="k-default-codeblock">
```
Calling end-to-end model on test string...
Model output: tf.Tensor([[0.04514679]], shape=(1, 1), dtype=float32)
Model output: tf.Tensor([[0.01447889]], shape=(1, 1), dtype=float32)
```
</div>
@@ -532,7 +538,7 @@ adapt_data = tf.constant(
"With ease and You beside",
]
)
# Instantiate TextVectorization with "binary" output_mode (multi-hot)
# Instantiate TextVectorization with "multi_hot" output_mode
# and ngrams=2 (index all bigrams)
text_vectorizer = preprocessing.TextVectorization(output_mode="multi_hot", ngrams=2)
# Index the bigrams via `adapt()`
@@ -550,11 +556,11 @@ model = keras.Model(inputs, outputs)

# Create a labeled dataset (which includes unknown tokens)
train_dataset = tf.data.Dataset.from_tensor_slices(
(["\nThe Brain is deeper than the sea"], [1])
(["The Brain is deeper than the sea", "for if they are held Blue to Blue"], [1, 0])
)

# Preprocess the string inputs, turning them into int sequences
train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))
train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))
# Train the model on the int sequences
print("\nTraining model...")
model.compile(optimizer="rmsprop", loss="mse")
@@ -584,14 +590,14 @@ Encoded text:
<div class="k-default-codeblock">
```
Training model...
1/1 [==============================] - 0s 183ms/step - loss: 2.6441
1/1 [==============================] - 0s 222ms/step - loss: 1.4333
```
</div>

<div class="k-default-codeblock">
```
Calling end-to-end model on test string...
Model output: tf.Tensor([[-1.207074]], shape=(1, 1), dtype=float32)
Model output: tf.Tensor([[-0.89536154]], shape=(1, 1), dtype=float32)
```
</div>
@@ -628,11 +634,11 @@ model = keras.Model(inputs, outputs)

# Create a labeled dataset (which includes unknown tokens)
train_dataset = tf.data.Dataset.from_tensor_slices(
(["\nThe Brain is deeper than the sea"], [1])
(["The Brain is deeper than the sea", "for if they are held Blue to Blue"], [1, 0])
)

# Preprocess the string inputs, turning them into int sequences
train_dataset = train_dataset.batch(1).map(lambda x, y: (text_vectorizer(x), y))
train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))
# Train the model on the int sequences
print("\nTraining model...")
model.compile(optimizer="rmsprop", loss="mse")
@@ -667,14 +673,14 @@ Encoded text:
<div class="k-default-codeblock">
```
Training model...
1/1 [==============================] - 0s 184ms/step - loss: 0.7904
1/1 [==============================] - 0s 216ms/step - loss: 19.4452
```
</div>

<div class="k-default-codeblock">
```
Calling end-to-end model on test string...
Model output: tf.Tensor([[0.8694465]], shape=(1, 1), dtype=float32)
Model output: tf.Tensor([[-0.36555034]], shape=(1, 1), dtype=float32)
```
</div>
