Skip to content

Commit

Permalink
Update wide_deep_cross_networks example with tensorflow 2.5 changes (k…
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw authored May 7, 2021
1 parent d2fb727 commit 3551725
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 177 deletions.
27 changes: 13 additions & 14 deletions examples/structured_data/ipynb/wide_deep_cross_networks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n",
"**Date created:** 2020/12/31<br>\n",
"**Last modified:** 2020/12/31<br>\n",
"**Last modified:** 2021/05/05<br>\n",
"**Description:** Using Wide & Deep and Deep & Cross networks for structured data classification."
]
},
Expand All @@ -28,7 +28,7 @@
"1. [Wide & Deep](https://ai.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) models\n",
"2. [Deep & Cross](https://arxiv.org/abs/1708.05123) models\n",
"\n",
"Note that this example should be run with TensorFlow 2.3 or higher."
"Note that this example should be run with TensorFlow 2.5 or higher."
]
},
{
Expand Down Expand Up @@ -419,7 +419,6 @@
"outputs": [],
"source": [
"\n",
"from tensorflow.keras.layers.experimental.preprocessing import CategoryEncoding\n",
"from tensorflow.keras.layers.experimental.preprocessing import StringLookup\n",
"\n",
"\n",
Expand All @@ -431,25 +430,25 @@
" # Create a lookup to convert string values to an integer indices.\n",
" # Since we are not using a mask token nor expecting any out of vocabulary\n",
" # (oov) token, we set mask_token to None and num_oov_indices to 0.\n",
" index = StringLookup(\n",
" vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n",
" lookup = StringLookup(\n",
" vocabulary=vocabulary,\n",
" mask_token=None,\n",
" num_oov_indices=0,\n",
" output_mode=\"int\" if use_embedding else \"binary\",\n",
" )\n",
" # Convert the string input values into integer indices.\n",
" value_index = index(inputs[feature_name])\n",
" if use_embedding:\n",
" # Convert the string input values into integer indices.\n",
" encoded_feature = lookup(inputs[feature_name])\n",
" embedding_dims = int(math.sqrt(len(vocabulary)))\n",
" # Create an embedding layer with the specified dimensions.\n",
" embedding_ecoder = layers.Embedding(\n",
" embedding = layers.Embedding(\n",
" input_dim=len(vocabulary), output_dim=embedding_dims\n",
" )\n",
" # Convert the index values to embedding representations.\n",
" encoded_feature = embedding_ecoder(value_index)\n",
" encoded_feature = embedding(encoded_feature)\n",
" else:\n",
" # Create a one-hot encoder.\n",
" onehot_encoder = CategoryEncoding(output_mode=\"binary\")\n",
" onehot_encoder.adapt(index(vocabulary))\n",
" # Convert the index values to a one-hot representation.\n",
" encoded_feature = onehot_encoder(value_index)\n",
" # Convert the string input values into a one hot encoding.\n",
" encoded_feature = lookup(tf.expand_dims(inputs[feature_name], -1))\n",
" else:\n",
" # Use the numerical features as-is.\n",
" encoded_feature = tf.expand_dims(inputs[feature_name], -1)\n",
Expand Down
Loading

0 comments on commit 3551725

Please sign in to comment.