Get compatible with optimizer migration in TF 2.11 #2766

Merged 4 commits on Oct 11, 2022

Changes from all commits
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/adabelief.py
@@ -201,7 +201,8 @@ def _resource_apply_dense(self, grad, var):
sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)

m_t = m.assign(
beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking
beta_1_t * m + (1.0 - beta_1_t) * grad,
use_locking=self._use_locking,
)
m_corr_t = m_t / (1.0 - beta_1_power)

25 changes: 18 additions & 7 deletions tensorflow_addons/optimizers/average_wrapper.py
@@ -25,19 +25,29 @@
class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
@typechecked
def __init__(
self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
self,
optimizer: types.Optimizer,
name: str = "AverageOptimizer",
**kwargs,
):
super().__init__(name, **kwargs)

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if (
Review comment (Contributor): Do you think that we could call an external function with this logic instead of replicating the code in every wrapper? (A sketch of such a helper follows this file's diff.)

hasattr(tf.keras.optimizers, "legacy")
and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
):
optimizer = tf.keras.optimizers.get(
optimizer, use_legacy_optimizer=True
)
else:
optimizer = tf.keras.optimizers.get(optimizer)

if not isinstance(
optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
):
if not isinstance(optimizer, KerasLegacyOptimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer "
"or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
"or tf.keras.optimizers.legacy.Optimizer "
"(if you have tf version >= 2.11.0)."
)

self._optimizer = optimizer
@@ -135,7 +145,8 @@ def assign_average_vars(self, var_list):
try:
assign_ops.append(
var.assign(
self.get_slot(var, "average"), use_locking=self._use_locking
self.get_slot(var, "average"),
use_locking=self._use_locking,
)
)
except Exception as e:
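
The reviewer's suggestion above could be handled by one shared helper instead of repeating this resolution logic in every wrapper. A minimal sketch of what such a helper might look like (the name `resolve_optimizer` is hypothetical and not part of this PR; it only re-packages the logic shown in the diff):

```python
import tensorflow as tf
from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer


def resolve_optimizer(optimizer):
    """Turn a string identifier into an optimizer instance, preferring the
    legacy class when the addons wrappers target it (TF >= 2.11)."""
    if isinstance(optimizer, str):
        if (
            hasattr(tf.keras.optimizers, "legacy")
            and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
        ):
            optimizer = tf.keras.optimizers.get(optimizer, use_legacy_optimizer=True)
        else:
            optimizer = tf.keras.optimizers.get(optimizer)
    if not isinstance(optimizer, KerasLegacyOptimizer):
        raise TypeError(
            "optimizer is not an object of tf.keras.optimizers.Optimizer "
            "or tf.keras.optimizers.legacy.Optimizer "
            "(if you have tf version >= 2.11.0)."
        )
    return optimizer
```

Both `AveragedOptimizerWrapper` and `Lookahead` could then call `resolve_optimizer(optimizer)` in their constructors.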
9 changes: 7 additions & 2 deletions tensorflow_addons/optimizers/constants.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import importlib
import tensorflow as tf

if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
if (
hasattr(tf.keras.optimizers, "experimental")
and tf.keras.optimizers.Optimizer.__module__
== tf.keras.optimizers.experimental.Optimizer.__module__
):
# If the default optimizer points to new Keras optimizer, addon optimizers
# should use the legacy path.
KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
else:
KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
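
A quick way to see which base class the addons optimizers will build on, assuming the import path matches this module:

```python
import tensorflow as tf
from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer

# With TF >= 2.11 the default tf.keras.optimizers.Optimizer is the new
# (formerly experimental) implementation, so the constant points at the
# legacy base class; on older releases it is tf.keras.optimizers.Optimizer.
print(KerasLegacyOptimizer.__module__, KerasLegacyOptimizer.__name__)
```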
22 changes: 15 additions & 7 deletions tensorflow_addons/optimizers/lookahead.py
@@ -71,13 +71,19 @@ def __init__(
super().__init__(name, **kwargs)

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(
optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
):
if (
hasattr(tf.keras.optimizers, "legacy")
and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
):
optimizer = tf.keras.optimizers.get(
optimizer, use_legacy_optimizer=True
)
else:
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(optimizer, KerasLegacyOptimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer "
"or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
"or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.11.0)."
)

self._optimizer = optimizer
@@ -119,10 +125,12 @@ def _look_ahead_op(self, var):
)
with tf.control_dependencies([step_back]):
slow_update = slow_var.assign(
tf.where(sync_cond, step_back, slow_var), use_locking=self._use_locking
tf.where(sync_cond, step_back, slow_var),
use_locking=self._use_locking,
)
var_update = var.assign(
tf.where(sync_cond, step_back, var), use_locking=self._use_locking
tf.where(sync_cond, step_back, var),
use_locking=self._use_locking,
)
return tf.group(slow_update, var_update)

79 changes: 65 additions & 14 deletions tensorflow_addons/optimizers/tests/moving_average_test.py
@@ -31,7 +31,10 @@ def test_run():

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5)
else:
opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

opt.apply_gradients(grads_and_vars)
opt.apply_gradients(grads_and_vars)
@@ -95,7 +98,10 @@ def test_model_weights_update():
)
model.build(input_shape=[1, 1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5)
else:
opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
_ = opt.apply_gradients(list(zip([grad], model.variables)))
np.testing.assert_allclose(model.variables[0].read_value(), [[0.8]])
_ = opt.assign_average_vars(model.variables)
@@ -115,8 +121,10 @@ def test_model_dynamic_lr():
]
)
model.build(input_shape=[1, 1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=1e-3), average_decay=0.5)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=1e-3), average_decay=0.5)
else:
opt = MovingAverage(tf.keras.optimizers.SGD(lr=1e-3), average_decay=0.5)
_ = opt.apply_gradients(list(zip([grad], model.variables)))
np.testing.assert_allclose(opt.lr.read_value(), 1e-3)
opt.lr = 1e-4
@@ -129,9 +137,20 @@ def test_optimizer_string():


def test_config():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
if hasattr(tf.keras.optimizers, "legacy"):
sgd_opt = tf.keras.optimizers.legacy.SGD(
lr=2.0, nesterov=True, momentum=0.3, decay=0.1
)
else:
sgd_opt = tf.keras.optimizers.SGD(
lr=2.0, nesterov=True, momentum=0.3, decay=0.1
)
opt = MovingAverage(
sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True
sgd_opt,
average_decay=0.5,
num_updates=None,
start_step=5,
dynamic_decay=True,
)
config = opt.get_config()

@@ -177,9 +196,20 @@ def test_fit_simple_linear_model():


def test_serialization():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
if hasattr(tf.keras.optimizers, "legacy"):
sgd_opt = tf.keras.optimizers.legacy.SGD(
lr=2.0, nesterov=True, momentum=0.3, decay=0.1
)
else:
sgd_opt = tf.keras.optimizers.SGD(
lr=2.0, nesterov=True, momentum=0.3, decay=0.1
)
optimizer = MovingAverage(
sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True
sgd_opt,
average_decay=0.5,
num_updates=None,
start_step=5,
dynamic_decay=True,
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
@@ -215,9 +245,18 @@ def test_dynamic_decay():
grads0 = tf.constant([0.1, 0.1])
grads_and_vars = [(grads0, var0)]

opt = MovingAverage(
tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5, dynamic_decay=True
)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(
tf.keras.optimizers.legacy.SGD(lr=2.0),
average_decay=0.5,
dynamic_decay=True,
)
else:
opt = MovingAverage(
tf.keras.optimizers.SGD(lr=2.0),
average_decay=0.5,
dynamic_decay=True,
)

opt.apply_gradients(grads_and_vars)
opt.apply_gradients(grads_and_vars)
@@ -235,7 +274,12 @@ def test_swap_weight_no_shadow_copy(device):
var = tf.Variable([1.0, 2.0])
grads = tf.constant([0.1, 0.1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(
tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5
)
else:
opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

@tf.function
def apply_gradients():
@@ -267,7 +311,12 @@ def test_swap_weights(device):
var = tf.Variable([1.0, 2.0])
grads = tf.constant([0.1, 0.1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
if hasattr(tf.keras.optimizers, "legacy"):
opt = MovingAverage(
tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5
)
else:
opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

@tf.function
def apply_gradients():
@@ -314,7 +363,9 @@ def test_no_average_slot():
# They are returned when using model.variables
# but it's unable to assign average slot to them.
vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
max_tokens=max_features, output_mode="int", output_sequence_length=max_len
max_tokens=max_features,
output_mode="int",
output_sequence_length=max_len,
)

vectorize_layer.adapt(["foo", "bar", "baz"])
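
The `hasattr(tf.keras.optimizers, "legacy")` guard recurs in nearly every test above; a small module-level helper (hypothetical, not part of this PR) would keep the tests shorter:

```python
import tensorflow as tf


def _sgd(**kwargs):
    # Use the legacy SGD when the legacy namespace exists, matching what the
    # MovingAverage wrapper expects on TF >= 2.11.
    if hasattr(tf.keras.optimizers, "legacy"):
        return tf.keras.optimizers.legacy.SGD(**kwargs)
    return tf.keras.optimizers.SGD(**kwargs)


# e.g. in test_run():
# opt = MovingAverage(_sgd(lr=2.0), average_decay=0.5)
```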
@@ -28,7 +28,10 @@
def test_averaging():
start_averaging = 0
average_period = 1
sgd = tf.keras.optimizers.SGD(lr=1.0)
if hasattr(tf.keras.optimizers, "legacy"):
sgd = tf.keras.optimizers.legacy.SGD(lr=1.0)
else:
sgd = tf.keras.optimizers.SGD(lr=1.0)
optimizer = SWA(sgd, start_averaging, average_period)

val_0 = [1.0, 1.0]
@@ -81,7 +84,10 @@ def test_assign_batchnorm():
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(1))

opt = SWA(tf.keras.optimizers.SGD())
if hasattr(tf.keras.optimizers, "legacy"):
opt = SWA(tf.keras.optimizers.legacy.SGD())
else:
opt = SWA(tf.keras.optimizers.SGD())
model.compile(optimizer=opt, loss="mean_squared_error")
model.fit(x, y, epochs=1)

@@ -118,7 +124,10 @@ def test_fit_simple_linear_model():
def test_serialization():
start_averaging = 0
average_period = 1
sgd = tf.keras.optimizers.SGD(lr=1.0)
if hasattr(tf.keras.optimizers, "legacy"):
sgd = tf.keras.optimizers.legacy.SGD(lr=1.0)
else:
sgd = tf.keras.optimizers.SGD(lr=1.0)
optimizer = SWA(sgd, start_averaging, average_period)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
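
Outside the test suite, the same pattern applies when constructing SWA directly; a minimal usage sketch under that assumption (hyperparameters are illustrative only):

```python
import tensorflow as tf
from tensorflow_addons.optimizers import SWA

# Pick the SGD implementation the wrapper expects on this TF version.
if hasattr(tf.keras.optimizers, "legacy"):
    base = tf.keras.optimizers.legacy.SGD(learning_rate=1.0)
else:
    base = tf.keras.optimizers.SGD(learning_rate=1.0)

opt = SWA(base, start_averaging=0, average_period=1)
```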
21 changes: 18 additions & 3 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -127,7 +127,13 @@ def from_config(cls, config, custom_objects=None):
return cls(**config)

def minimize(
self, loss, var_list, grad_loss=None, name=None, decay_var_list=None, tape=None
self,
loss,
var_list,
grad_loss=None,
name=None,
decay_var_list=None,
tape=None,
):
"""Minimize `loss` by updating `var_list`.

@@ -354,7 +360,10 @@ class OptimizerWithDecoupledWeightDecay(

@typechecked
def __init__(
self, weight_decay: Union[FloatTensorLike, Callable], *args, **kwargs
self,
weight_decay: Union[FloatTensorLike, Callable],
*args,
**kwargs,
):
# super delegation is necessary here
super().__init__(weight_decay, *args, **kwargs)
@@ -441,8 +450,14 @@ def __init__(
)


if hasattr(tf.keras.optimizers, "legacy"):
ADAM_CLASS = tf.keras.optimizers.legacy.Adam
else:
ADAM_CLASS = tf.keras.optimizers.Adam


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS):
"""Optimizer that implements the Adam algorithm with weight decay.

This is an implementation of the AdamW optimizer described in "Decoupled
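
With `ADAM_CLASS` resolved this way, `AdamW` keeps its public behaviour across TF versions; a minimal compile/fit sketch (model shape and hyperparameters are illustrative only):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# On TF >= 2.11 AdamW derives from tf.keras.optimizers.legacy.Adam,
# on older releases from tf.keras.optimizers.Adam; usage is unchanged.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3),
    loss="mse",
)
model.fit(tf.zeros([8, 4]), tf.zeros([8, 1]), epochs=1, verbose=0)
```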
7 changes: 6 additions & 1 deletion tools/testing/source_code_test.py
@@ -41,6 +41,7 @@ def test_api_typed():
# Files within this list will be exempt from verification.
exception_list = [
tfa.rnn.PeepholeLSTMCell,
tf.keras.optimizers.Optimizer,
]
if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
exception_list.append(tf.keras.optimizers.legacy.Optimizer)
@@ -50,7 +51,10 @@
"https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#about-type-hints"
)
ensure_api_is_typed(
modules_list, exception_list, init_only=True, additional_message=help_message
modules_list,
exception_list,
init_only=True,
additional_message=help_message,
)


@@ -151,6 +155,7 @@ def test_no_experimental_api():
# TODO: remove all elements of the list and remove the allowlist
# This allowlist should not grow. Do not add elements to this list.
allowlist = [
"tensorflow_addons/optimizers/constants.py",
"tensorflow_addons/optimizers/weight_decay_optimizers.py",
"tensorflow_addons/layers/max_unpooling_2d.py",
"tensorflow_addons/image/dense_image_warp.py",