Skip to content

Commit cbb3682

Browse files
authored
Allow variable layout to be set on construction. (#21264)
If a variable has a particular layout (e.g. is sharded across multiple devices/hosts), then any corresponding auxiliary variables, like optimizer gradient accumulators, need to have the same layout. This requires us to set the variable layout prior to initialization so that it is initialized correctly and efficiently across devices. Added optional `kwargs` to the base `Variable` class so subclasses can handle (and ignore) any backend-specific options. Modified `JaxVariable` to allow setting the layout parameter on construction. Modified `add_variable_from_reference` to copy the layout from the reference variable.
1 parent 583306d commit cbb3682

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

keras/src/backend/common/variables.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Variable:
5151
value: The current value of the variable (NumPy array or tensor).
5252
name: The name of the variable (string).
5353
path: The path of the variable within the Keras model or layer (string).
54+
kwargs: Additional backend-specific keyword arguments.
5455
5556
Examples:
5657
@@ -98,7 +99,9 @@ def __init__(
9899
aggregation="none",
99100
synchronization="auto",
100101
name=None,
102+
**kwargs,
101103
):
104+
del kwargs
102105
name = name or auto_name(self.__class__.__name__)
103106
if not isinstance(name, str) or "/" in name:
104107
raise ValueError(

keras/src/backend/jax/core.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,30 @@
2020

2121

2222
class Variable(KerasVariable):
23+
def __init__(self, *args, layout=None, **kwargs):
24+
# Intercept layout parameter so that it is available
25+
# during initialization.
26+
self._layout = layout
27+
super().__init__(*args, **kwargs)
28+
2329
def _initialize(self, value):
2430
# Note that variable.shape is needed by distribution_lib
2531
self._shape = self._validate_shape(value.shape)
2632
# We can't import the keras/distribution/distribution_lib
2733
# due to circular dependency.
2834
distribution = global_state.get_global_attribute("distribution")
29-
if distribution is not None:
30-
self._layout = distribution.get_variable_layout(self).backend_layout
31-
else:
32-
self._layout = None
35+
if self._layout is None and distribution is not None:
36+
tensor_layout = distribution.get_variable_layout(self)
37+
from keras.src.distribution import TensorLayout
38+
39+
if isinstance(tensor_layout, TensorLayout):
40+
self._layout = tensor_layout.backend_layout
41+
else:
42+
self._layout = tensor_layout
3343
self._direct_assign(value)
3444

3545
def _direct_assign(self, value):
36-
if getattr(self, "_layout", None) is not None:
46+
if self._layout is not None:
3747
value = distribution_lib.distribute_variable(value, self._layout)
3848
self._value = value
3949

keras/src/optimizers/base_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def add_variable(
244244
initializer="zeros",
245245
dtype=None,
246246
aggregation="none",
247+
layout=None,
247248
name=None,
248249
):
249250
"""Add a variable to the optimizer.
@@ -261,6 +262,7 @@ def add_variable(
261262
the type of multi-replica aggregation to be used for this
262263
variable when writing custom data parallel training loops.
263264
Defaults to `"none"`.
265+
layout: Optional tensor layout. Defaults to `None`.
264266
name: String name of the variable. Useful for debugging purposes.
265267
266268
Returns:
@@ -275,6 +277,7 @@ def add_variable(
275277
dtype=dtype,
276278
trainable=False,
277279
aggregation=aggregation,
280+
layout=layout,
278281
name=name,
279282
)
280283
self._track_variable(variable)
@@ -319,6 +322,7 @@ def add_variable_from_reference(
319322
initializer=initializer,
320323
dtype=reference_variable.dtype,
321324
name=name,
325+
layout=getattr(reference_variable, "_layout", None),
322326
)
323327

324328
def add_optimizer_variables(

0 commit comments

Comments (0)