
Commit edb8df5

Johannes Ballé authored and copybara-github committed
Revises handling of quantization offsets.
- The quantization offset for DeepFactorized is now determined numerically
  instead of assuming zero.
- For batched entropy models, the `non_integer_offset` argument controls
  whether the quantization offset heuristic is used or not (as before).
- For location-scale family entropy models, always quantize to integers
  modulo the location parameter of the prior distribution.
- For general indexed entropy models, do not use the quantization offset
  heuristic, and always quantize to integers.
- Universal entropy models use their own logic, as before.

The above is accomplished by some refactoring:

- The logic for creating the range coding tables is moved from the
  initializer of the base class in continuous_base.py to the initializers of
  the subclasses. This streamlines the building of the range coding tables
  and makes that logic available as a private method that subclasses can
  call.
- Models in universal.py now depend directly on the base class. This way,
  they don't need to inherit the quantization offset logic and can implement
  their own.

Both of these changes remove indirection. They also free parent classes from
having to implement functionality they don't need, and child classes from
inheriting functionality that doesn't make sense for them.

PiperOrigin-RevId: 420564225
Change-Id: I57cdd9627b83db3a2455a23d9481ccb23309f957
1 parent c60a5a9 · commit edb8df5

24 files changed (+826 -516 lines)
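Note: the rule above for location-scale families (quantize to integers modulo the location parameter) amounts to rounding relative to the prior's location. A minimal sketch of that rule, with a hypothetical helper name, not the library's code:

import tensorflow as tf

def quantize_mod_location(x, loc):
  # Hypothetical helper: snap x to the integer grid shifted by loc, so the
  # quantization bins stay unit-width but align with the prior's location.
  return tf.round(x - loc) + loc

# With loc = 0.3, values snap to {..., -0.7, 0.3, 1.3, 2.3, ...}:
print(quantize_mod_location(tf.constant([0.1, 0.9, 2.6]), 0.3))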

BUILD

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ py_library(
         "//tensorflow_compression/python/ops:gen_ops",
         "//tensorflow_compression/python/ops:math_ops",
         "//tensorflow_compression/python/ops:padding_ops",
-        "//tensorflow_compression/python/ops:soft_round_ops",
+        "//tensorflow_compression/python/ops:round_ops",
         "//tensorflow_compression/python/util:packed_tensors",
     ],
 )

tensorflow_compression/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@
 from tensorflow_compression.python.ops.gen_ops import *
 from tensorflow_compression.python.ops.math_ops import *
 from tensorflow_compression.python.ops.padding_ops import *
-from tensorflow_compression.python.ops.soft_round_ops import *
+from tensorflow_compression.python.ops.round_ops import *
 
 from tensorflow_compression.python.util.packed_tensors import *
 # pylint: enable=wildcard-import

tensorflow_compression/all_tests.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
 from tensorflow_compression.python.ops.math_ops_test import *
 from tensorflow_compression.python.ops.padding_ops_test import *
 from tensorflow_compression.python.ops.range_coding_ops_test import *
-from tensorflow_compression.python.ops.soft_round_ops_test import *
+from tensorflow_compression.python.ops.round_ops_test import *
 
 from tensorflow_compression.python.util.packed_tensors_test import *
 # pylint: enable=wildcard-import

tensorflow_compression/python/distributions/BUILD

Lines changed: 1 addition & 2 deletions
@@ -65,7 +65,7 @@ py_library(
         ":deep_factorized",
         ":helpers",
         ":uniform_noise",
-        "//tensorflow_compression/python/ops:soft_round_ops",
+        "//tensorflow_compression/python/ops:round_ops",
     ],
 )
 
@@ -76,7 +76,6 @@ py_test(
     deps = [
         ":deep_factorized",
         ":round_adapters",
-        "//tensorflow_compression/python/ops:soft_round_ops",
     ],
 )

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 6 additions & 3 deletions
@@ -239,15 +239,18 @@ def _log_prob(self, inputs):
     return log_s_logits + log_s_neg_logits + tf.math.log(dlogits)
 
   def _quantization_offset(self):
-    return tf.constant(0, dtype=self.dtype)
+    return helpers.estimate_tails(
+        self._logits_cumulative, 0., self.batch_shape_tensor(), self.dtype)
 
   def _lower_tail(self, tail_mass):
-    logits = tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
+    logits = tf.math.log(
+        tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
     return helpers.estimate_tails(
         self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)
 
   def _upper_tail(self, tail_mass):
-    logits = -tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
+    logits = -tf.math.log(
+        tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
     return helpers.estimate_tails(
         self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)
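Note on the `_quantization_offset` change above: `estimate_tails` is given a target logit of 0, and logit(c(x)) = 0 exactly where the cumulative c(x) = 1/2, so the offset is now a numerical estimate of the prior's median rather than the constant zero assumed before. A small illustration of that identity with a distribution whose median is known analytically (illustration only, not the DeepFactorized code):

import tensorflow as tf
import tensorflow_probability as tfp

# The median of a logistic distribution is its loc, so the logit of the CDF
# crosses zero exactly there.
d = tfp.distributions.Logistic(loc=2.5, scale=1.0)
c = d.cdf(2.5)
print(tf.math.log(c) - tf.math.log(1. - c))  # ~0 at the median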

tensorflow_compression/python/distributions/deep_factorized_test.py

Lines changed: 0 additions & 4 deletions
@@ -111,10 +111,6 @@ def test_uniform_is_special_case(self):
     x = tf.linspace(-1., 1., 10)
     self.assertAllClose(df.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0])
 
-  def test_quantization_offset_is_zero(self):
-    df = deep_factorized.NoisyDeepFactorized()
-    self.assertEqual(helpers.quantization_offset(df), 0)
-
   def test_tails_are_in_order(self):
     df = deep_factorized.NoisyDeepFactorized()
     lower_tail = helpers.lower_tail(df, 2**-8)

tensorflow_compression/python/distributions/helpers.py

Lines changed: 5 additions & 3 deletions
@@ -157,9 +157,10 @@ def lower_tail(distribution, tail_mass):
     tail = distribution.quantile(tail_mass / 2)
   except NotImplementedError:
     try:
+      target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype))
       tail = estimate_tails(
-          distribution.log_cdf, tf.math.log(tail_mass / 2),
-          distribution.batch_shape_tensor(), distribution.dtype)
+          distribution.log_cdf, target, distribution.batch_shape_tensor(),
+          distribution.dtype)
     except NotImplementedError:
       raise NotImplementedError(
           "`distribution` must implement `_lower_tail()`, `quantile()`, or "
@@ -193,8 +194,9 @@ def upper_tail(distribution, tail_mass):
     tail = distribution.quantile(1 - tail_mass / 2)
   except NotImplementedError:
     try:
+      target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype))
       tail = estimate_tails(
-          distribution.log_survival_function, tf.math.log(tail_mass / 2),
+          distribution.log_survival_function, target,
           distribution.batch_shape_tensor(), distribution.dtype)
     except NotImplementedError:
       raise NotImplementedError(

tensorflow_compression/python/distributions/round_adapters.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 from tensorflow_compression.python.distributions import deep_factorized
 from tensorflow_compression.python.distributions import helpers
 from tensorflow_compression.python.distributions import uniform_noise
-from tensorflow_compression.python.ops import soft_round_ops
+from tensorflow_compression.python.ops import round_ops
 
 
 __all__ = [
@@ -239,10 +239,10 @@ def __init__(self, base, alpha, name="SoftRoundAdapter"):
     self._alpha = alpha
 
   def transform(self, x):
-    return soft_round_ops.soft_round(x, self._alpha)
+    return round_ops.soft_round(x, self._alpha)
 
   def inverse_transform(self, y):
-    return soft_round_ops.soft_round_inverse(y, self._alpha)
+    return round_ops.soft_round_inverse(y, self._alpha)
 
 
 class NoisySoftRoundAdapter(uniform_noise.UniformNoiseAdapter):
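The edits above are a pure module rename: `soft_round` and `soft_round_inverse` keep their signatures, only `soft_round_ops` becomes `round_ops`. A usage sketch of the transform pair, assuming the post-rename import path from this diff (soft rounding is a differentiable approximation of `tf.round`, and the inverse undoes it):

import tensorflow as tf
from tensorflow_compression.python.ops import round_ops

x = tf.linspace(-2.0, 2.0, 9)
y = round_ops.soft_round(x, 5.0)               # differentiable ~rounding
x_back = round_ops.soft_round_inverse(y, 5.0)  # analytic inverse
print(tf.reduce_max(tf.abs(x - x_back)))       # ~0: the pair round-trips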

tensorflow_compression/python/distributions/round_adapters_test.py

Lines changed: 14 additions & 28 deletions
@@ -20,7 +20,6 @@
 import tensorflow_probability as tfp
 from tensorflow_compression.python.distributions import deep_factorized
 from tensorflow_compression.python.distributions import round_adapters
-from tensorflow_compression.python.ops import soft_round_ops
 
 
 def _test_log_prob_gradient_is_bounded(self, dist_cls, values, params=()):
@@ -42,44 +41,42 @@ class AdaptersTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.named_parameters(
       ("softround_deepfactorized",
        lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
-       deep_factorized.DeepFactorized, 0.0),
+       deep_factorized.DeepFactorized),
       ("softround_logistic",
        lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
-       lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5),
-       lambda: soft_round_ops.soft_round(0.3, alpha=5.0)),
+       lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5)),
       ("softround_normal",
        lambda d: round_adapters.SoftRoundAdapter(d, alpha=4.0),
-       lambda: tfp.distributions.Normal(loc=10.4, scale=1.5),
-       lambda: soft_round_ops.soft_round(0.4, alpha=4.0)),
+       lambda: tfp.distributions.Normal(loc=10.4, scale=1.5)),
       ("noisysoftround_deepfactorized",
        lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
-       deep_factorized.DeepFactorized, 0.0),
+       deep_factorized.DeepFactorized),
       ("noisysoftround_logistic",
        lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
-       lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 0.0),
+       lambda: tfp.distributions.Logistic(loc=10, scale=1.5)),
       ("noisysoftround_normal",
        lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
-       lambda: tfp.distributions.Normal(loc=10, scale=1.5), 0.0),
+       lambda: tfp.distributions.Normal(loc=10, scale=1.5)),
       ("round_deepfactorized",
        round_adapters.RoundAdapter,
-       lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0),
+       lambda: deep_factorized.DeepFactorized(init_scale=1.0)),
       ("round_logistic",
        round_adapters.RoundAdapter,
-       lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0),
+       lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)),
       ("round_normal",
        round_adapters.RoundAdapter,
-       lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0),
+       lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)),
       ("noisyround_deepfactorized",
        round_adapters.NoisyRoundAdapter,
-       lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0),
+       lambda: deep_factorized.DeepFactorized(init_scale=1.0)),
       ("noisyround_logistic",
        round_adapters.NoisyRoundAdapter,
-       lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0),
+       lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)),
       ("noisyround_normal",
       round_adapters.NoisyRoundAdapter,
-       lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0),
+       lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)),
   )
-  def test_tails_and_offset(self, adapter, distribution, expected_offset):
+  def test_tails(self, adapter, distribution):
     dist = adapter(distribution())
     lower_tail = dist._lower_tail(2**-8)
     try:
@@ -98,12 +95,6 @@ def test_tails_and_offset(self, adapter, distribution, expected_offset):
       self.assertLessEqual(right_mass, 2**-8)
 
     self.assertGreater(upper_tail, lower_tail)
-    offset = dist._quantization_offset()
-    if not isinstance(expected_offset, float):
-      # We cannot run tf inside the parameterized test declaration, hence
-      # non-float values are wrapped in a lambda.
-      expected_offset = expected_offset()
-    self.assertAllClose(offset, expected_offset)
 
   @parameterized.named_parameters(
       ("softround_logistic",
@@ -210,16 +201,11 @@ def test_sampling_works(self):
     sample = dist.sample((5, 4))
     self.assertEqual(sample.shape, (5, 4, 2))
 
-  def test_tails_and_offset_are_in_order(self):
+  def test_tails_are_in_order(self):
     dist = self.dist_cls(loc=10, scale=1.5)
-    offset = dist._quantization_offset()
     lower_tail = dist._lower_tail(2**-8)
     upper_tail = dist._upper_tail(2**-8)
     self.assertGreater(upper_tail, lower_tail)
-    if offset:
-      # If quantization offset is 0.0, it doesn't need to be between the tails.
-      self.assertGreater(upper_tail, offset)
-      self.assertGreater(offset, lower_tail)
 
   def test_stats_throw_error(self):
     dist = self.dist_cls(loc=1, scale=2)

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 5 additions & 4 deletions
@@ -23,6 +23,7 @@ py_library(
         "//tensorflow_compression/python/distributions:helpers",
         "//tensorflow_compression/python/ops:gen_ops",
         "//tensorflow_compression/python/ops:math_ops",
+        "//tensorflow_compression/python/ops:round_ops",
     ],
 )
 
@@ -42,16 +43,17 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":continuous_base",
-        "//tensorflow_compression/python/distributions:helpers",
         "//tensorflow_compression/python/ops:gen_ops",
         "//tensorflow_compression/python/ops:math_ops",
+        "//tensorflow_compression/python/ops:round_ops",
     ],
 )
 
 py_test(
     name = "continuous_indexed_test",
     srcs = ["continuous_indexed_test.py"],
     python_version = "PY3",
+    shard_count = 5,
     deps = [
         ":continuous_indexed",
         "//tensorflow_compression/python/distributions:uniform_noise",
@@ -63,15 +65,14 @@ py_library(
     srcs = ["universal.py"],
     srcs_version = "PY3",
    deps = [
-        ":continuous_batched",
-        ":continuous_indexed",
+        ":continuous_base",
+        "//tensorflow_compression/python/ops:gen_ops",
         "//tensorflow_compression/python/ops:math_ops",
     ],
 )
 
 py_test(
     name = "universal_test",
-    timeout = "long",
     srcs = ["universal_test.py"],
     python_version = "PY3",
     shard_count = 3,