
Commit

Update docs on use of TF2 SavedModels in distributed Estimators.
PiperOrigin-RevId: 294610969
TensorFlow Hub Authors authored and vbardiovskyg committed Feb 18, 2020
1 parent 16ae221 commit d9a850a
Showing 3 changed files with 29 additions and 0 deletions.
22 changes: 22 additions & 0 deletions docs/migration_tf2.md
@@ -51,6 +51,28 @@ Many tutorials show these APIs in action. See in particular
If the hub.Module you use has a newer version that comes in the TF2 SavedModel
format, we recommend switching the API and the module version at the same time.

### Using the new API in Estimator training

If you use a TF2 SavedModel in an Estimator for training with parameter servers
(or otherwise in a TF1 Session with variables placed on remote devices),
you need to set `experimental.share_cluster_devices_in_session` in the
tf.Session's ConfigProto, or else you will get an error like
"Assigned device '/job:ps/replica:0/task:0/device:CPU:0'
does not match any device."

The necessary option can be set as follows:

```python
session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)
```

Starting with TF 2.2, this option is no longer experimental, and
the `.experimental` piece can be dropped.
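For TF 2.2 and later, the same configuration can therefore be written without the
`.experimental` prefix. A minimal sketch, following the sentence above (the
`...` placeholders for other `RunConfig`/`Estimator` arguments are omitted here):

```python
import tensorflow as tf  # assumes TF 2.2 or later

session_config = tf.compat.v1.ConfigProto()
# Promoted out of `experimental` in TF 2.2.
session_config.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(session_config=session_config)
```

The resulting `run_config` is then passed to the Estimator via its `config`
argument, exactly as in the example above.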


## Loading legacy hub.Modules

It can happen that a new TF2 SavedModel is not yet available for your
3 changes: 3 additions & 0 deletions tensorflow_hub/keras_layer.py
@@ -83,6 +83,9 @@ class KerasLayer(tf.keras.layers.Layer):
for guidance on how to pick up trainable variables, losses and updates
explicitly from Keras objects instead of relying on graph collections.
This layer class does not support graph collections.
Distributed training of the Estimator requires setting the option
`session_config.experimental.share_cluster_devices_in_session` within
the `tf.estimator.RunConfig`. (It becomes non-experimental in TF2.2.)
Note: The data types used by a saved model have been fixed at saving time.
Using tf.keras.mixed_precision etc. has no effect on the saved model
4 changes: 4 additions & 0 deletions tensorflow_hub/module_v2.py
@@ -63,6 +63,10 @@ def load(handle, tags=None):
on the result of `hub.resolve(handle)`. Calling this function requires
TF 1.14 or newer. It can be called both in eager and graph mode.
Note: Using this in a tf.compat.v1.Session with variables placed on parameter
servers requires setting `experimental.share_cluster_devices_in_session`
within the `tf.compat.v1.ConfigProto`. (It becomes non-experimental in TF2.2.)
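A minimal sketch of the note above. The session target and the module handle
are hypothetical placeholders, and TF 1.15+/2.x with tensorflow_hub installed
is assumed:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Allow the session to assign ops to devices of other jobs (e.g. parameter
# servers); this option becomes non-experimental in TF 2.2.
config = tf.compat.v1.ConfigProto()
config.experimental.share_cluster_devices_in_session = True

graph = tf.compat.v1.Graph() if hasattr(tf.compat.v1, "Graph") else tf.Graph()
with graph.as_default():
  # Hypothetical handle; variables may end up placed on parameter servers.
  module = hub.load("https://tfhub.dev/<publisher>/<model>/<version>")
  with tf.compat.v1.Session(target="grpc://...", config=config) as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
```

Without the config option, device assignment fails with an error like the
"Assigned device ... does not match any device" message quoted in
docs/migration_tf2.md above.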
This function can handle the deprecated hub.Module format to the extent
that `tf.saved_model.load()` in TF2 does. In particular, the returned object
has attributes
