- Layer classes now use `variable_dtype` for variables and `compute_dtype` for
computation, as laid out in https://www.tensorflow.org/guide/mixed_precision
(see the first sketch below).
- `Parameter` classes use the dtype passed to `__init__` for creating the
variable, and the dtype optionally passed to `__call__` for transforming the
parameter (see the second sketch below).
- In entropy models, the `dtype` argument is dropped. They now define a
`bottleneck_dtype` argument giving the dtype of the bottleneck, which defaults
to `tf.keras.mixed_precision.global_policy().compute_dtype`. This is
consistent with Keras; when mixed precision is not in use, it defaults to
`tf.keras.backend.floatx()`, which in turn is `tf.float32` by default (see the
third sketch below).
- The dtype of the prior and any probability computations is kept separate from
all of the above. Batched models take that dtype directly from the
distribution object. Indexed models have a new argument `prior_dtype`, which
is used to instantiate the prior for any computations. Both this and the dtype
of `DeepFactorized` default to `tf.float32` (see the fourth sketch below).
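The sketches below illustrate these conventions; import aliases and example
class choices are assumptions, not part of this change. First, the standard
Keras mixed-precision split between `variable_dtype` and `compute_dtype` that
the layer classes now follow, using `tfc.GDN` as an example layer:

```python
import tensorflow as tf
import tensorflow_compression as tfc  # assumed import alias

# Under a mixed-precision policy, Keras keeps variables in float32
# (variable_dtype) and runs the computation in float16 (compute_dtype).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tfc.GDN()  # example tensorflow_compression layer
y = layer(tf.random.uniform((1, 8, 8, 32), dtype=tf.float16))

print(layer.variable_dtype)  # float32
print(layer.compute_dtype)   # float16
print(y.dtype)               # tf.float16

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default
```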
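Second, a hypothetical minimal sketch (not the actual class) of the `Parameter`
convention: the variable is created in the dtype given to `__init__`, and
`__call__` returns the transformed value in an optionally different dtype:

```python
import tensorflow as tf

class ExampleParameter:
  """Hypothetical stand-in illustrating the `Parameter` dtype convention."""

  def __init__(self, initial_value, dtype=tf.float32):
    # The variable is created in the dtype passed to `__init__`.
    self.variable = tf.Variable(tf.cast(initial_value, dtype))

  def __call__(self, dtype=None):
    # The transformation (here just a cast) is returned in the dtype
    # optionally passed to `__call__`.
    value = tf.convert_to_tensor(self.variable)
    return tf.cast(value, dtype) if dtype is not None else value

param = ExampleParameter([1.0, 2.0], dtype=tf.float32)
print(param.variable.dtype)     # storage stays float32
print(param(tf.float16).dtype)  # computation value is float16
```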
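Third, the bottleneck dtype of entropy models: it follows the global policy's
compute dtype unless overridden. The model construction below is a sketch
based on the public `tensorflow_compression` API:

```python
import tensorflow as tf
import tensorflow_compression as tfc  # assumed import alias

# Without a mixed-precision policy, the global compute dtype is
# tf.keras.backend.floatx(), which is float32 by default.
print(tf.keras.mixed_precision.global_policy().compute_dtype)  # float32
print(tf.keras.backend.floatx())                               # float32

# The default can be overridden per model via `bottleneck_dtype`:
prior = tfc.NoisyDeepFactorized(batch_shape=(32,))
em = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, bottleneck_dtype=tf.float16)

bottleneck = tf.random.normal((8, 32), dtype=tf.float16)
_, bits = em(bottleneck, training=True)  # rates computed in the prior's dtype
```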
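Fourth, the prior dtype: `DeepFactorized` defaults to `tf.float32` regardless
of the mixed-precision policy, and indexed models expose the new `prior_dtype`
argument. The indexed constructor arguments shown in the comment are
illustrative assumptions, not prescribed by this change:

```python
import tensorflow as tf
import tensorflow_compression as tfc  # assumed import alias

# Batched models take the prior/probability dtype directly from the
# distribution object; DeepFactorized defaults to float32.
prior = tfc.DeepFactorized(batch_shape=(32,))
print(prior.dtype)  # tf.float32

# Indexed models instead take `prior_dtype`, used to instantiate the prior
# for probability computations, roughly along these lines (arguments other
# than `prior_dtype` are illustrative; see the class docstring):
#
#   em = tfc.ContinuousIndexedEntropyModel(
#       prior_fn=tfc.NoisyNormal, index_ranges=(64,),
#       parameter_fns=dict(loc=lambda _: 0.0,
#                          scale=lambda i: tf.exp(i / 8.0 - 5.0)),
#       coding_rank=1, prior_dtype=tf.float32)
```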
PiperOrigin-RevId: 427359879
Change-Id: Ie163c80253b391641e7537034516f9e4d1ebe36d