
Commit c20abdb

Johannes Ballé authored and copybara-github committed
Enables mixed precision training.
- Layer classes now use `variable_dtype` for variables, and `compute_dtype` for computation, as laid out in https://www.tensorflow.org/guide/mixed_precision.
- `Parameter` classes use the dtype passed to `__init__` for creating the variable, and the dtype optionally passed to `__call__` for transforming the parameter.
- In entropy models, the `dtype` argument is dropped, and they now define a `bottleneck_dtype` argument giving the dtype of the bottleneck, which defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`. This is consistent with Keras and, if not using mixed precision, defaults to `tf.keras.backend.floatx()`, which in turn is `tf.float32` by default.
- The dtype of the prior and any probability computations is kept separate from all of the above. The batched models take that dtype directly from the distribution object. Indexed models have a new argument `prior_dtype`, which is used to instantiate the prior for any computations. Both this and the dtype of `DeepFactorized` default to `tf.float32`.

PiperOrigin-RevId: 427359879
Change-Id: Ie163c80253b391641e7537034516f9e4d1ebe36d
1 parent 61602f0 commit c20abdb

13 files changed: +274 −132 lines
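To illustrate the behavior described in the commit message (and exercised by the new mixed-precision test below), here is a minimal usage sketch. It assumes the package is installed and imported as `tensorflow_compression as tfc`, which exposes `NoisyNormal` and `ContinuousBatchedEntropyModel`; exact symbols and defaults may differ between releases.

```python
import tensorflow as tf
import tensorflow_compression as tfc

# Under a mixed precision policy, the bottleneck dtype defaults to the
# policy's compute dtype, while the prior keeps its own (separate) dtype.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
try:
  prior = tfc.NoisyNormal(loc=tf.constant(0., tf.float64),
                          scale=tf.constant(1., tf.float64))
  em = tfc.ContinuousBatchedEntropyModel(prior, coding_rank=1, compression=True)
  assert em.bottleneck_dtype == tf.float16  # from the global policy
  assert em.prior.dtype == tf.float64       # prior dtype kept separate

  x = tf.random.stateless_normal((2, 5), seed=(0, 1), dtype=tf.float16)
  x_tilde, bits = em(x)       # x_tilde is float16; bits follow the prior dtype
  x_hat = em.decompress(em.compress(x), (5,))  # decompressed in float16
finally:
  # Restore the default (float32) policy so later code is unaffected.
  tf.keras.mixed_precision.set_global_policy(None)
```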

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 16 additions & 15 deletions
@@ -43,7 +43,7 @@ def __init__(self,
                stateless=False,
                expected_grads=False,
                tail_mass=2**-8,
-               dtype=None,
+               bottleneck_dtype=None,
                laplace_tail_mass=0):
     """Initializes the instance.

@@ -66,8 +66,8 @@ def __init__(self,
         backpropagation w.r.t. additive uniform noise.
       tail_mass: Float. Approximate probability mass which is encoded using an
         Elias gamma code embedded into the range coder.
-      dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
-        prior, decompressed values).
+      bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
+        Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
       laplace_tail_mass: Float. If non-zero, will augment the prior with a
         Laplace mixture for training stability. (experimental)
     """
@@ -78,7 +78,9 @@ def __init__(self,
     self._stateless = bool(stateless)
     self._expected_grads = bool(expected_grads)
     self._tail_mass = float(tail_mass)
-    self._dtype = tf.as_dtype(dtype)
+    if bottleneck_dtype is None:
+      bottleneck_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
+    self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
     self._laplace_tail_mass = float(laplace_tail_mass)

     if self.coding_rank < 0:
@@ -88,10 +90,6 @@ def __init__(self,
     if not 0 <= self.laplace_tail_mass < 1:
       raise ValueError("`laplace_tail_mass` must be between 0 and 1.")

-    with self.name_scope:
-      self._laplace_prior = (tfp.distributions.Laplace(loc=0., scale=1.)
-                             if laplace_tail_mass else None)
-
   def _check_compression(self):
     if not self.compression:
       raise RuntimeError(
@@ -123,9 +121,9 @@ def cdf_offset(self):
     return tf.convert_to_tensor(self._cdf_offset)

   @property
-  def dtype(self):
-    """Data type of this entropy model."""
-    return self._dtype
+  def bottleneck_dtype(self):
+    """Data type of the bottleneck tensor."""
+    return self._bottleneck_dtype

   @property
   def expected_grads(self):
@@ -247,7 +245,7 @@ def _build_tables(self, prior, precision, offset=None):
     maxima = tf.cast(tf.math.ceil(upper_tail - offset), tf.int32)

     # PMF starting positions and lengths.
-    pmf_start = tf.cast(minima, self.dtype) + offset
+    pmf_start = tf.cast(minima, prior.dtype) + offset
     pmf_length = maxima - minima + 1

     # Sample the densities in the computed ranges, possibly computing more
@@ -258,7 +256,7 @@ def _build_tables(self, prior, precision, offset=None):
           "Very wide PMF with %d elements may lead to out of memory issues. "
           "Consider priors with smaller variance, or increasing `tail_mass` "
           "parameter.", int(max_length))
-    samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
+    samples = tf.range(tf.cast(max_length, prior.dtype), dtype=prior.dtype)
     samples = tf.reshape(samples, [-1] + pmf_length.shape.rank * [1])
     samples += pmf_start
     pmf = prior.prob(samples)
@@ -294,8 +292,11 @@ def loop_body(i, cdf):

   def _log_prob(self, prior, bottleneck_perturbed):
     """Evaluates prior.log_prob(bottleneck + noise)."""
+    bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
     if self.laplace_tail_mass:
-      laplace_prior = self._laplace_prior
+      laplace_prior = tfp.distributions.Laplace(
+          loc=tf.constant(0, dtype=prior.dtype),
+          scale=tf.constant(1, dtype=prior.dtype))
       probs = prior.prob(bottleneck_perturbed)
       probs = ((1 - self.laplace_tail_mass) * probs +
                self.laplace_tail_mass *
@@ -332,7 +333,7 @@ def get_config(self):
         expected_grads=self.expected_grads,
         tail_mass=self.tail_mass,
         cdf_shapes=(self.cdf.shape[0], self.cdf_offset.shape[0]),
-        dtype=self.dtype.name,
+        bottleneck_dtype=self.bottleneck_dtype.name,
         laplace_tail_mass=self.laplace_tail_mass,
     )


tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 19 additions & 14 deletions
@@ -115,7 +115,7 @@ def __init__(self,
                expected_grads=False,
                tail_mass=2**-8,
                range_coder_precision=12,
-               dtype=None,
+               bottleneck_dtype=None,
                prior_shape=None,
                cdf=None,
                cdf_offset=None,
@@ -153,8 +153,8 @@ def __init__(self,
       tail_mass: Float. Approximate probability mass which is encoded using an
         Elias gamma code embedded into the range coder.
       range_coder_precision: Integer. Precision passed to the range coding op.
-      dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
-        prior, decompressed values). Must be provided if `prior` is omitted.
+      bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
+        Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
       prior_shape: Batch shape of the prior (dimensions which are not assumed
         i.i.d.). Must be provided if `prior` is omitted.
       cdf: `tf.Tensor` or `None`. If provided, is used for range coding rather
@@ -171,9 +171,8 @@ def __init__(self,
       laplace_tail_mass: Float. If positive, will augment the prior with a
         Laplace mixture for training stability. (experimental)
     """
-    if not (prior is not None) == (dtype is None) == (prior_shape is None):
-      raise ValueError(
-          "Either `prior` or both `dtype` and `prior_shape` must be provided.")
+    if (prior is None) == (prior_shape is None):
+      raise ValueError("Either `prior` or `prior_shape` must be provided.")
     if (prior is None) + (cdf_shapes is None) + (cdf is None) != 2:
       raise ValueError(
           "Must provide exactly one of `prior`, `cdf`, or `cdf_shapes`.")
@@ -189,7 +188,7 @@ def __init__(self,
         stateless=stateless,
         expected_grads=expected_grads,
         tail_mass=tail_mass,
-        dtype=dtype if dtype is not None else prior.dtype,
+        bottleneck_dtype=bottleneck_dtype,
         laplace_tail_mass=laplace_tail_mass,
     )
     self._prior = prior
@@ -209,8 +208,7 @@ def __init__(self,
       assert isinstance(quantization_offset, bool)
       assert self.compression
       if quantization_offset:
-        quantization_offset = tf.zeros(
-            self.prior_shape_tensor, dtype=self.dtype)
+        quantization_offset = tf.zeros(self.prior_shape_tensor)
       else:
         quantization_offset = None
     elif quantization_offset is not None:
@@ -236,12 +234,15 @@ def __init__(self,
     if quantization_offset is None:
       self._quantization_offset = None
     elif self.compression and not self.stateless:
+      quantization_offset = tf.cast(
+          quantization_offset, self.bottleneck_dtype)
       self._quantization_offset = tf.Variable(
-          quantization_offset, dtype=self.dtype, trainable=False,
-          name="quantization_offset")
+          quantization_offset, trainable=False, name="quantization_offset")
     else:
+      quantization_offset = tf.cast(
+          quantization_offset, self.bottleneck_dtype)
       self._quantization_offset = tf.convert_to_tensor(
-          quantization_offset, dtype=self.dtype, name="quantization_offset")
+          quantization_offset, name="quantization_offset")
     if self.compression:
       if cdf is None and cdf_shapes is None:
         cdf, cdf_offset = self._build_tables(
@@ -276,7 +277,8 @@ def quantization_offset(self):
           "tf.function. Ideally, the offset heuristic should only be used "
           "to determine offsets once after training. Depending on the prior, "
           "estimating the offset might be computationally expensive.")
-      return helpers.quantization_offset(self.prior)
+      return tf.cast(
+          helpers.quantization_offset(self.prior), self.bottleneck_dtype)
     return None

   @tf.Module.with_name_scope
@@ -299,6 +301,7 @@ def __call__(self, bottleneck, training=True):
       `bits` has the same shape as `bottleneck` without the `self.coding_rank`
       innermost dimensions.
     """
+    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
     log_prob_fn = functools.partial(self._log_prob, self.prior)
     if training:
       log_probs, bottleneck_perturbed = math_ops.perturb_and_apply(
@@ -331,6 +334,7 @@ def quantize(self, bottleneck):
     Returns:
       A `tf.Tensor` containing the quantized values.
     """
+    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
     return round_ops.round_st(bottleneck, self.quantization_offset)

   @tf.Module.with_name_scope
@@ -356,6 +360,7 @@ def compress(self, bottleneck):
       `self.coding_rank` innermost dimensions, containing a string for each
       coding unit.
     """
+    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
     input_shape = tf.shape(bottleneck)
     all_but_last_n_elems = lambda t, n: t[:-n] if n else t
     batch_shape = all_but_last_n_elems(input_shape, self.coding_rank)
@@ -400,7 +405,7 @@ def decompress(self, strings, broadcast_shape):
     tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
     symbols += self.cdf_offset
     symbols = tf.reshape(symbols, output_shape)
-    outputs = tf.cast(symbols, self.dtype)
+    outputs = tf.cast(symbols, self.bottleneck_dtype)
     offset = self.quantization_offset
     if offset is not None:
       outputs += offset

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 27 additions & 4 deletions
@@ -30,7 +30,8 @@ def test_can_instantiate(self):
     self.assertIs(em.prior, noisy)
     self.assertEqual(em.coding_rank, 1)
     self.assertEqual(em.tail_mass, 2**-8)
-    self.assertEqual(em.dtype, noisy.dtype)
+    self.assertEqual(em.bottleneck_dtype, tf.float32)
+    self.assertEqual(em.prior.dtype, tf.float32)

   def test_can_instantiate_statelessly(self):
     noisy = uniform_noise.NoisyNormal(loc=.25, scale=1.)
@@ -41,8 +42,7 @@ def test_can_instantiate_statelessly(self):
     self.assertAllEqual(.25, em.quantization_offset)
     em = ContinuousBatchedEntropyModel(
         compression=True, stateless=True, coding_rank=1,
-        prior_shape=noisy.batch_shape, dtype=noisy.dtype,
-        cdf=em.cdf, cdf_offset=em.cdf_offset,
+        prior_shape=noisy.batch_shape, cdf=em.cdf, cdf_offset=em.cdf_offset,
         quantization_offset=em.quantization_offset,
     )
     self.assertEqual(em.compression, True)
@@ -53,7 +53,7 @@ def test_can_instantiate_statelessly(self):
     self.assertEqual(em.coding_rank, 1)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
-    self.assertEqual(em.dtype, noisy.dtype)
+    self.assertEqual(em.bottleneck_dtype, tf.float32)

   def test_requires_scalar_distributions(self):
     noisy = uniform_noise.UniformNoiseAdapter(
@@ -194,6 +194,29 @@ def compress(self, values):
     self.assertAllClose(samples, values_eager, rtol=0., atol=.5)
     self.assertAllEqual(values_eager, values_function)

+  def test_dtypes_are_correct_with_mixed_precision(self):
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
+    try:
+      noisy = uniform_noise.NoisyNormal(
+          loc=tf.constant(0, dtype=tf.float64),
+          scale=tf.constant(1, dtype=tf.float64))
+      em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
+      self.assertEqual(em.bottleneck_dtype, tf.float16)
+      self.assertEqual(em.prior.dtype, tf.float64)
+      x = tf.random.stateless_normal((2, 5), seed=(0, 1), dtype=tf.float16)
+      x_tilde, bits = em(x)
+      bitstring = em.compress(x)
+      x_hat = em.decompress(bitstring, (5,))
+      self.assertEqual(x_hat.dtype, tf.float16)
+      self.assertAllClose(x, x_hat, rtol=0, atol=.5)
+      self.assertEqual(x_tilde.dtype, tf.float16)
+      self.assertAllClose(x, x_tilde, rtol=0, atol=.5)
+      self.assertEqual(bits.dtype, tf.float64)
+      self.assertEqual(bits.shape, (2,))
+      self.assertAllGreaterEqual(bits, 0.)
+    finally:
+      tf.keras.mixed_precision.set_global_policy(None)
+
   def test_small_cdfs_for_dirac_prior_without_quantization_offset(self):
     prior = uniform_noise.NoisyNormal(loc=100. * tf.range(16.), scale=1e-10)
     em = ContinuousBatchedEntropyModel(

0 commit comments