Commit c0017df

Don't create unused optimizer variables. (#21232)
If `variable.overwrite_with_gradient == True`, the only optimizer variable ever used for that variable is `base_optimizer._accumulated_gradients`; all other optimizer variables are unused. This can be extremely wasteful when the trainable variables are large, for example large embedding tables that span multiple hosts/devices.

This change adds a convenience method on the base optimizer, `add_optimizer_variables(...)`, which loops through the variable list and creates an optimizer variable only when appropriate. If a variable's optimizer slot would otherwise go unused, `None` is inserted into the list instead; this keeps `optimizer._get_variable_index()` consistent. All built-in optimizers have been updated to use it.

NOTE: a custom optimizer in the wild that still creates unused optimizer variables will continue to work - it will just be wasteful. In other words, this should not be a breaking change.
1 parent 16ef8fb · commit c0017df

File tree: 15 files changed (+132, -178 lines)
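As a rough sketch of what adopting the new helper can look like in a custom optimizer (the class and its hyperparameters below are hypothetical and not part of this commit; only `add_optimizer_variables`, `_get_variable_index`, and the `assign*` helpers come from the base class):

from keras.src import ops
from keras.src.optimizers import optimizer


class MyMomentumOptimizer(optimizer.Optimizer):
    """Toy heavy-ball momentum optimizer using the new helper."""

    def __init__(self, learning_rate=0.01, momentum=0.9, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # One slot per trainable variable; slots for variables with
        # `overwrite_with_gradient=True` come back as `None`, so no memory
        # is allocated for state that would never be read.
        self._momentums = self.add_optimizer_variables(var_list, "momentum")

    def update_step(self, gradient, variable, learning_rate):
        # Variables flagged with `overwrite_with_gradient` are handled by
        # the base class before `update_step` is called, so the `None`
        # slots are never touched here.
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        m = self._momentums[self._get_variable_index(variable)]
        # m <- momentum * m + g;  variable <- variable - lr * m
        self.assign(m, ops.add(ops.multiply(self.momentum, m), gradient))
        self.assign_sub(variable, ops.multiply(lr, m))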


keras/src/backend/torch/optimizers/torch_parallel_optimizer.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,9 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
 
     @torch_utils.no_grad
     def _backend_reset_gradient_accumulators(self):
-        acc_list = [v.value for v in self._accumulated_gradients]
+        acc_list = [
+            v.value for v in self._accumulated_gradients if v is not None
+        ]
         torch._foreach_mul_(acc_list, 0.0)
 
     @torch_utils.no_grad

keras/src/optimizers/adadelta.py

Lines changed: 6 additions & 9 deletions
@@ -75,15 +75,12 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulated_grads = []
-        self._accumulated_delta_vars = []
-        for var in var_list:
-            self._accumulated_grads.append(
-                self.add_variable_from_reference(var, "accumulated_grad")
-            )
-            self._accumulated_delta_vars.append(
-                self.add_variable_from_reference(var, "accumulated_delta_var")
-            )
+        self._accumulated_grads = self.add_optimizer_variables(
+            var_list, "accumulated_grad"
+        )
+        self._accumulated_delta_vars = self.add_optimizer_variables(
+            var_list, "accumulated_delta_var"
+        )
 
     def update_step(self, grad, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adafactor.py

Lines changed: 15 additions & 15 deletions
@@ -1,4 +1,3 @@
-from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -97,16 +96,13 @@ def build(self, var_list):
         self._c = []
         self._v = []
         for var in var_list:
-            if len(var.shape) < 2:
-                # Don't factor if variable is of dimension < 2, but we still
-                # need to create dummy variables as placeholder.
-                with backend.name_scope(self.name, caller=self):
-                    self._r.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
-                    self._c.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
+            if (
+                self._overwrite_variable_with_gradient(var)
+                or len(var.shape) < 2
+            ):
+                # Don't factor if variable is of dimension < 2.
+                self._r.append(None)
+                self._c.append(None)
             else:
                 # Always factor the last 2 dimensions.
                 r_shape = var.shape[:-1]
@@ -125,11 +121,15 @@ def build(self, var_list):
                         name=var.name,
                     )
                 )
-            self._v.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
+
+            if self._overwrite_variable_with_gradient(var):
+                self._v.append(None)
+            else:
+                self._v.append(
+                    self.add_variable_from_reference(
+                        reference_variable=var, name="velocity"
+                    )
                 )
-            )
 
     def _rms(self, x):
         return ops.sqrt(ops.mean(ops.square(x)))

keras/src/optimizers/adagrad.py

Lines changed: 3 additions & 10 deletions
@@ -70,17 +70,10 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulators = []
         initializer = initializers.Constant(self.initial_accumulator_value)
-        for var in var_list:
-            self._accumulators.append(
-                self.add_variable(
-                    shape=var.shape,
-                    initializer=initializer,
-                    dtype=var.dtype,
-                    name="accumulator",
-                )
-            )
+        self._accumulators = self.add_optimizer_variables(
+            var_list, "accumulator", initializer=initializer
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adam.py

Lines changed: 6 additions & 20 deletions
@@ -90,27 +90,13 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        self._velocities = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._velocities.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
+        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+
         if self.amsgrad:
-            self._velocity_hats = []
-            for var in var_list:
-                self._velocity_hats.append(
-                    self.add_variable_from_reference(
-                        reference_variable=var, name="velocity_hat"
-                    )
-                )
+            self._velocity_hats = self.add_optimizer_variables(
+                var_list, "velocity_hat"
+            )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adamax.py

Lines changed: 2 additions & 13 deletions
@@ -98,19 +98,8 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._m = []
-        self._u = []
-        for var in var_list:
-            self._m.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._u.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="norm"
-                )
-            )
+        self._m = self.add_optimizer_variables(var_list, "momentum")
+        self._u = self.add_optimizer_variables(var_list, "norm")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/base_optimizer.py

Lines changed: 62 additions & 18 deletions
@@ -204,21 +204,19 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _overwrite_variable_with_gradient(self, variable):
+        return getattr(variable, "overwrite_with_gradient", False)
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
-            self._model_variables_moving_average = []
+            self._model_variables_moving_average = self.add_optimizer_variables(
+                variables, "average"
+            )
         if self.gradient_accumulation_steps:
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
             self._trainable_variables_indices[self._var_key(variable)] = i
-            if self.use_ema:
-                self._model_variables_moving_average.append(
-                    self.add_variable_from_reference(
-                        variable,
-                        name="average",
-                    )
-                )
             if self.gradient_accumulation_steps:
                 self._accumulated_gradients.append(
                     self.add_variable_from_reference(
@@ -323,6 +321,49 @@ def add_variable_from_reference(
             name=name,
         )
 
+    def add_optimizer_variables(
+        self, trainable_variables, name, initializer="zeros"
+    ):
+        """Add optimizer variables from the list of trainable model variables.
+
+        Create an optimizer variable based on the information of the supplied
+        model variables. For example, in the SGD optimizer with momentum, a
+        corresponding momentum variable is created for each model variable,
+        with the same shape and dtype.
+
+        Note that for trainable variables with
+        `overwrite_with_gradient == True`, `None` is inserted into the output
+        list instead, since the optimizer variable would never be used.
+
+        Args:
+            trainable_variables: list of `keras.Variable`s; the model
+                variables used to create corresponding optimizer variables.
+            name: The name prefix of the optimizer variable to be created.
+                The variable name will follow the pattern
+                `{variable_name}_{trainable_variable.name}`, e.g.,
+                `momentum/dense_1`.
+            initializer: Initializer object to use to populate the initial
+                variable value, or string name of a built-in initializer (e.g.
+                `"random_normal"`). If unspecified, defaults to `"zeros"`.
+
+        Returns:
+            A list of optimizer variables, in the format of `keras.Variable`s.
+        """
+        optimizer_variables = []
+        for variable in trainable_variables:
+            if not self._overwrite_variable_with_gradient(variable):
+                optimizer_variables.append(
+                    self.add_variable_from_reference(
+                        variable,
+                        name=name,
+                        initializer=initializer,
+                    )
+                )
+            else:
+                optimizer_variables.append(None)
+
+        return optimizer_variables
+
     def _check_variables_are_known(self, variables):
         for v in variables:
             if self._var_key(v) not in self._trainable_variables_indices:
@@ -544,7 +585,8 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
 
     def _backend_reset_gradient_accumulators(self):
         for g_acc in self._accumulated_gradients:
-            g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
+            if g_acc is not None:
+                g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
 
     def _backend_increment_gradient_accumulators(self, grads, acc_grads):
         new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)]
@@ -711,8 +753,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         After the update, the processed pairs will be filtered out.
         """
         # Shortcut for `tf.Variable` because it doesn't have a
-        # `overwrite_with_gradient` attr
-        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
+        # `overwrite_with_gradient` attr.
+        if not any(self._overwrite_variable_with_gradient(v) for v in vars):
             return grads, vars
 
         # Shallow copies
@@ -722,7 +764,7 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         # Iterate from right to left for safe popping
        for i in range(len(filtered_grads) - 1, -1, -1):
             g, v = filtered_grads[i], filtered_vars[i]
-            if v.overwrite_with_gradient:
+            if self._overwrite_variable_with_gradient(v):
                 if self.gradient_accumulation_steps:
                     # Utilize a stateless manner for JAX compatibility
                     steps = self.gradient_accumulation_steps
@@ -886,11 +928,12 @@ def _update_model_variables_moving_average(self, trainable_variables):
         for var, average in zip(
             trainable_variables, self._model_variables_moving_average
         ):
-            not_first_step = ops.not_equal(self.iterations, 0)
-            momentum = (
-                ops.cast(not_first_step, var.dtype) * self.ema_momentum
-            )
-            average.assign(momentum * average + (1 - momentum) * var)
+            if average is not None:
+                not_first_step = ops.not_equal(self.iterations, 0)
+                momentum = (
+                    ops.cast(not_first_step, var.dtype) * self.ema_momentum
+                )
+                average.assign(momentum * average + (1 - momentum) * var)
 
     def _overwrite_model_variables_with_average_value(
         self, trainable_variables
@@ -909,7 +952,8 @@ def _overwrite_model_variables_with_average_value(
         for var, average_var in zip(
             trainable_variables, self._model_variables_moving_average
         ):
-            var.assign(average_var)
+            if average_var is not None:
+                var.assign(average_var)
 
     def finalize_variable_values(self, var_list):
         """Set the final value of model's trainable variables.

keras/src/optimizers/ftrl.py

Lines changed: 7 additions & 18 deletions
@@ -159,24 +159,13 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulators = []
-        self._linears = []
-        for var in var_list:
-            self._accumulators.append(
-                self.add_variable(
-                    shape=var.shape,
-                    dtype=var.dtype,
-                    name="accumulator",
-                    initializer=initializers.Constant(
-                        self.initial_accumulator_value,
-                    ),
-                )
-            )
-            self._linears.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="linear"
-                )
-            )
+        accumulator_initializer = initializers.Constant(
+            self.initial_accumulator_value,
+        )
+        self._accumulators = self.add_optimizer_variables(
+            var_list, "accumulator", initializer=accumulator_initializer
+        )
+        self._linears = self.add_optimizer_variables(var_list, "linear")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/lamb.py

Lines changed: 2 additions & 13 deletions
@@ -82,19 +82,8 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        self._velocities = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._velocities.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
+        self._velocities = self.add_optimizer_variables(var_list, "velocity")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/lion.py

Lines changed: 1 addition & 7 deletions
@@ -91,13 +91,7 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
