9 changes: 4 additions & 5 deletions keras/src/optimizers/adadelta.py
@@ -75,11 +75,10 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulated_grads = self.add_optimizer_variables(
-            var_list, "accumulated_grad"
-        )
-        self._accumulated_delta_vars = self.add_optimizer_variables(
-            var_list, "accumulated_delta_var"
+        self._accumulated_grads, self._accumulated_delta_vars = (
+            self.add_optimizer_variables(
+                var_list, ["accumulated_grad", "accumulated_delta_var"]
+            )
         )
 
     def update_step(self, grad, variable, learning_rate):
16 changes: 11 additions & 5 deletions keras/src/optimizers/adafactor.py
@@ -1,3 +1,4 @@
+from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -96,11 +97,16 @@ def build(self, var_list):
         self._c = []
         self._v = []
         for var in var_list:
-            if (
-                self._overwrite_variable_with_gradient(var)
-                or len(var.shape) < 2
-            ):
-                # Don't factor if variable is of dimension < 2.
+            if len(var.shape) < 2:
+                # Don't factor if variable is of dimension < 2, but we still
+                # need to create dummy variables as placeholder.
+                self._r.append(
+                    backend.Variable(0, name=var.name, trainable=False)
+                )
+                self._c.append(
+                    backend.Variable(0, name=var.name, trainable=False)
+                )
+            elif self._overwrite_variable_with_gradient(var):
                 self._r.append(None)
                 self._c.append(None)
             else:
5 changes: 3 additions & 2 deletions keras/src/optimizers/adam.py
@@ -90,8 +90,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
 
         if self.amsgrad:
             self._velocity_hats = self.add_optimizer_variables(
5 changes: 3 additions & 2 deletions keras/src/optimizers/adamax.py
@@ -98,8 +98,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._m = self.add_optimizer_variables(var_list, "momentum")
-        self._u = self.add_optimizer_variables(var_list, "norm")
+        self._m, self._u = self.add_optimizer_variables(
+            var_list, ["momentum", "norm"]
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
63 changes: 49 additions & 14 deletions keras/src/optimizers/base_optimizer.py
@@ -338,29 +338,64 @@ def add_optimizer_variables(
         Args:
             trainable_variables: `keras.Variable`, the corresponding model
                 variable to the optimizer variable to be created.
-            name: The name prefix of the optimizer variable to be created. The
-                variable name will follow the pattern
+            name: The name prefix(es) of the optimizer variable(s) to be
+                created. Can be a single string or list of strings. If a
+                list of strings, will create an optimizer variable for each
+                prefix. The variable name will follow the pattern
                 `{variable_name}_{trainable_variable.name}`, e.g.,
-                `momemtum/dense_1`. Defaults to `None`.
-            initializer: Initializer object to use to populate the initial
-                variable value, or string name of a built-in initializer (e.g.
-                `"random_normal"`). If unspecified, defaults to `"zeros"`.
+                `momemtum/dense_1`.
+            initializer: Initializer object(s) to use to populate the initial
+                variable value(s), or string name of a built-in initializer
+                (e.g. `"random_normal"`). If unspecified, defaults to
+                `"zeros"`.
 
         Returns:
             A list of optimizer variables, in the format of `keras.Variable`s.
+            If multiple names are provide, returns a tuple of lists.
         """
-        optimizer_variables = []
+        name_list = name
+        initializer_list = initializer
+        if isinstance(name, str):
+            # Single name/initializer.
+            name_list = [name]
+            initializer_list = [initializer]
+        else:
+            # Multiple names/initializers.
+            # If there is only one initializer, use it for all names.
+            if isinstance(initializer, str) or isinstance(
+                initializer, initializers.Initializer
+            ):
+                initializer_list = [initializer] * len(name_list)
+
+        if len(name_list) != len(initializer_list):
+            raise ValueError(
+                f"The number of provided names must match the number of "
+                f"provided initializers. Received name='{name}', "
+                f"initializer='{initializer}'"
+            )
+
+        # Build up lists of optimizer variables.
+        optimizer_variables = tuple([] for _ in name_list)
         for variable in trainable_variables:
+            # Interleaves adding variables for backward-compatibility.
             if not self._overwrite_variable_with_gradient(variable):
-                optimizer_variables.append(
-                    self.add_variable_from_reference(
-                        variable,
-                        name=name,
-                        initializer=initializer,
+                for i, (var_name, var_init) in enumerate(
+                    zip(name_list, initializer_list)
+                ):
+                    optimizer_variables[i].append(
+                        self.add_variable_from_reference(
+                            variable,
+                            name=var_name,
+                            initializer=var_init,
+                        )
                     )
-                )
             else:
-                optimizer_variables.append(None)
+                for i in range(len(name_list)):
+                    optimizer_variables[i].append(None)
+
+        # If single input name, return the single list.
+        if isinstance(name, str):
+            return optimizer_variables[0]
 
         return optimizer_variables
 
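For orientation, here is a short usage sketch of the updated `add_optimizer_variables()` documented above, assuming a Keras build that includes this change. The `ToyRMSpropMomentum` class, its hyperparameters, and its update rule are hypothetical illustrations, not part of this PR; only the `add_optimizer_variables()` call reflects the new behavior (a list of name prefixes, an optional per-name list of initializers, and a tuple of slot lists as the return value).

import keras
from keras import initializers, ops


class ToyRMSpropMomentum(keras.optimizers.Optimizer):
    """Hypothetical optimizer; only the build() call mirrors this PR."""

    def __init__(self, learning_rate=0.001, rho=0.9, momentum=0.9, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.rho = rho
        self.momentum = momentum

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # One call now creates both slot lists. A single initializer (or the
        # default "zeros") is broadcast to every name; a list pairs each name
        # with its own initializer, as the ftrl.py change below does.
        self._momentums, self._velocities = self.add_optimizer_variables(
            var_list,
            ["momentum", "velocity"],
            initializer=["zeros", initializers.Zeros()],
        )

    def update_step(self, gradient, variable, learning_rate):
        lr = ops.cast(learning_rate, variable.dtype)
        g = ops.cast(gradient, variable.dtype)
        i = self._get_variable_index(variable)
        m = self._momentums[i]
        v = self._velocities[i]
        # v <- rho * v + (1 - rho) * g^2
        self.assign(
            v, ops.add(ops.multiply(v, self.rho),
                       ops.multiply(ops.square(g), 1.0 - self.rho))
        )
        # m <- momentum * m + lr * g / sqrt(v + eps)
        self.assign(
            m, ops.add(ops.multiply(m, self.momentum),
                       ops.divide(ops.multiply(g, lr),
                                  ops.sqrt(ops.add(v, 1e-7))))
        )
        self.assign_sub(variable, m)

As in the built-in optimizers touched by this PR, each returned list is index-aligned with `var_list`, with `None` entries for variables whose gradients overwrite the variable directly.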
7 changes: 4 additions & 3 deletions keras/src/optimizers/ftrl.py
@@ -162,10 +162,11 @@ def build(self, var_list):
         accumulator_initializer = initializers.Constant(
             self.initial_accumulator_value,
         )
-        self._accumulators = self.add_optimizer_variables(
-            var_list, "accumulator", initializer=accumulator_initializer
+        self._accumulators, self._linears = self.add_optimizer_variables(
+            var_list,
+            ["accumulator", "linear"],
+            initializer=[accumulator_initializer, "zeros"],
         )
-        self._linears = self.add_optimizer_variables(var_list, "linear")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
5 changes: 3 additions & 2 deletions keras/src/optimizers/lamb.py
@@ -82,8 +82,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
5 changes: 3 additions & 2 deletions keras/src/optimizers/nadam.py
@@ -87,8 +87,9 @@ def build(self, var_list):
         else:
             dtype = backend.floatx()
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
         self._u_product = backend.Variable(1.0, dtype=dtype)
 
     def _backend_update_step(self, grads, trainable_variables, learning_rate):