Skip to content

Commit

Permalink
Fix TFLiteConverter.convert() by providing the input shape
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 305847047
  • Loading branch information
TensorFlow Hub Authors authored and vbardiovskyg committed Apr 14, 2020
1 parent d52ebe4 commit 647f758
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/colab/tf2_image_retraining.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"\n",
"### Looking for a tool instead?\n",
"\n",
"This is a TensorFlow coding tutorial. If you want a tool that just builds the TensorFlow or TF Lite model for, take a look at the [make_image_classifier](https://github.com/tensorflow/hub/tree/master/tensorflow_hub/tools/make_image_classifier) command-line tool that gets [installed](https://www.tensorflow.org/hub/installation) by the PIP package `tensorflow-hub[make_image_classifier]`, or at [this](https://colab.sandbox.google.com/github/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/image_classification.ipynb) TF Lite colab.\n"
"This is a TensorFlow coding tutorial. If you want a tool that just builds the TensorFlow or TF Lite model for, take a look at the [make_image_classifier](https://github.com/tensorflow/hub/tree/master/tensorflow_hub/tools/make_image_classifier) command-line tool that gets [installed](https://www.tensorflow.org/hub/installation) by the PIP package `tensorflow-hub[make_image_classifier]`, or at [this](https://colab.sandbox.google.com/github/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/demo/image_classification.ipynb) TF Lite colab.\n"
]
},
{
Expand Down Expand Up @@ -255,6 +255,9 @@
"source": [
"print(\"Building model with\", MODULE_HANDLE)\n",
"model = tf.keras.Sequential([\n",
" # Explicitly define the input shape so the model can be properly\n",
" # loaded by the TFLiteConverter\n",
" tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),\n",
" hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(train_generator.num_classes,\n",
Expand Down Expand Up @@ -302,7 +305,7 @@
"source": [
"steps_per_epoch = train_generator.samples // train_generator.batch_size\n",
"validation_steps = valid_generator.samples // valid_generator.batch_size\n",
"hist = model.fit_generator(\n",
"hist = model.fit(\n",
" train_generator,\n",
" epochs=5, steps_per_epoch=steps_per_epoch,\n",
" validation_data=valid_generator,\n",
Expand Down

0 comments on commit 647f758

Please sign in to comment.