
Commit 72379f9

jerryxyj authored and Orbax Authors committed

Rename strip_xla_flags to persist_xla_flags and add validation.

PiperOrigin-RevId: 834824081

1 parent d966ddf · commit 72379f9

File tree

8 files changed: +155 -69 lines changed


export/orbax/export/constants.py

Lines changed: 7 additions & 2 deletions
@@ -97,8 +97,13 @@ class ExportModelType(enum.Enum):
 # Mesh for the model.
 JAX_MESH = 'jax_mesh'
 
-# Whether to strip XLA flags from the model.
-STRIP_XLA_FLAGS = 'strip_xla_flags'
+# Whether to persist XLA flags in the model.
+PERSIST_XLA_FLAGS = 'persist_xla_flags'
+
+# Whether to enable bf16 optimization for the model.
+# TODO_REGEX: b/422170690: (1): Apply this flag to the pre/post processors. (2):
+# Adding filter flags once the flag is applied to the pre/post processors.
+ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization'
 
 ################################################################################
 # Proto field names
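This diff does not show where the new keys are consumed; as orientation only, a jax2obm_kwargs dict using the renamed and added constants might look like the sketch below (import path assumed from this file's location).

# Sketch only: keys come from the constants above; where each key is read
# depends on code outside this diff.
from orbax.export import constants

jax2obm_kwargs = {
    constants.PERSIST_XLA_FLAGS: True,         # renamed from STRIP_XLA_FLAGS
    constants.ENABLE_BF16_OPTIMIZATION: True,  # new bf16 optimization switch
}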

export/orbax/export/jax_module.py

Lines changed: 10 additions & 0 deletions
@@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
         tensorflow_module.TensorFlowModule, self._export_module
     ).jax2tf_kwargs_map
 
+  @property
+  def jax2obm_kwargs(self) -> Mapping[str, Any]:
+    """Returns the jax2obm_kwargs."""
+    if self._export_version == constants.ExportModelType.TF_SAVEDMODEL:
+      raise TypeError(
+          'jax2obm_kwargs is not implemented for export version'
+          ' ExportModelType.TF_SAVEDMODEL.'
+      )
+    return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs
+
   @property
   def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
     """Returns the polymorphic shapes."""

export/orbax/export/modules/obm_module.py

Lines changed: 25 additions & 20 deletions
@@ -73,34 +73,40 @@ def __init__(
     )
 
     # It is possible for jax2obm_kwargs to be None if the key is present.
-    if not jax2obm_kwargs:
-      jax2obm_kwargs = {}
 
+    self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {}
+
+    enable_bf16_optimization = self.jax2obm_kwargs.get(
+        constants.ENABLE_BF16_OPTIMIZATION, False
+    )
+
+    if enable_bf16_optimization:
+      mapped_apply_fn = utils.to_bfloat16(apply_fn)
+      self._params_args_spec = utils.to_bfloat16(params)
+    else:
+      mapped_apply_fn = apply_fn
+      self._params_args_spec = params
     (
         self._apply_fn_map,
         self.input_polymorphic_shape_map,
         self.input_polymorphic_shape_symbol_values_map,
     ) = self._normalize_apply_fn_map(
-        apply_fn,
+        mapped_apply_fn,
        input_polymorphic_shape,
         input_polymorphic_shape_symbol_values,
     )
 
-    self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None)
-    self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False)
+    self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None)
 
-    self.polymorphic_constraints = self._maybe_set_polymorphic_constraints(
-        jax2obm_kwargs
-    )
+    self.polymorphic_constraints = self._maybe_set_polymorphic_constraints()
     self._native_serialization_platforms = utils.get_lowering_platforms(
-        jax2obm_kwargs
+        self.jax2obm_kwargs
     )
-    self._params_args_spec = params
 
     self._checkpoint_path: str = None
     # Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
-    self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
-    self._load_all_checkpoint_weights = jax2obm_kwargs.get(
+    self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs)
+    self._load_all_checkpoint_weights = self.jax2obm_kwargs.get(
         constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False
     )
 
@@ -203,15 +209,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
         else constants.DEFAULT_WEIGHTS_NAME
     )
 
-  def _maybe_set_polymorphic_constraints(
-      self, jax2obm_kwargs
-  ) -> Mapping[str, Sequence[Any]]:
+  def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]:
     """Sets the polymorphic constraints for the model.
 
-    Args:
-      jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
-        library.
-
     Returns:
       A mapping of function name to polymorphic constraints.
 
@@ -221,7 +221,7 @@ def _maybe_set_polymorphic_constraints(
       size of the apply_fn_map or if a key in apply_fn_map is not found in
       polymorphic_constraints.
     """
-    polymorphic_constraints = jax2obm_kwargs.get(
+    polymorphic_constraints = self.jax2obm_kwargs.get(
        constants.POLYMORPHIC_CONSTRAINTS, None
     )
     if not isinstance(polymorphic_constraints, Mapping):
@@ -300,3 +300,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
   def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
     """Named methods in JAX context for validation."""
     raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')
+
+  @property
+  def jax2obm_kwargs(self) -> Mapping[str, Any]:
+    """Returns the jax2obm_kwargs."""
+    return self._jax2obm_kwargs

export/orbax/export/modules/obm_module_test.py

Lines changed: 26 additions & 0 deletions
@@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns(
         jax2obm_kwargs=jax2obm_kwargs,
     )
 
+  @parameterized.named_parameters(
+      {'testcase_name': 'enable_bf16', 'enable_bf16_optimization': True},
+      {'testcase_name': 'disable_bf16', 'enable_bf16_optimization': False},
+  )
+  def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
+    params_spec = {
+        'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),
+        'b': jax.ShapeDtypeStruct((2,), jnp.float32),
+    }
+    input_spec = {constants.DEFAULT_METHOD_KEY: 'b, ...'}
+
+    module = obm_module.ObmModule(
+        params=params_spec,
+        apply_fn=_linear,
+        input_polymorphic_shape=input_spec,
+        jax2obm_kwargs={
+            constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization
+        },
+    )
+
+    expected_dtype = jnp.bfloat16 if enable_bf16_optimization else jnp.float32
+    with self.subTest('test_weights_w_dtype'):
+      self.assertEqual(module.model_params['w'].dtype, expected_dtype)
+    with self.subTest('test_weights_b_dtype'):
+      self.assertEqual(module.model_params['b'].dtype, expected_dtype)
+
 
 if __name__ == '__main__':
   absltest.main()

export/orbax/export/utils.py

Lines changed: 38 additions & 0 deletions
@@ -18,6 +18,7 @@
 import dataclasses
 import functools
 import inspect
+import jax.numpy as jnp
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -532,3 +533,40 @@ def get_lowering_platforms(
     )
 
   return native_serialization_platforms
+
+
+def to_bfloat16(x: Any) -> Any:
+  """Helper to convert leaves of a pytree to bfloat16.
+
+  It handles `float`, `jax.ShapeDtypeStruct`, and other array-like objects with
+  a floating point `dtype`.
+
+  Args:
+    x: The input pytree to convert.
+
+  Returns:
+    The input `x` with floating point values converted to `jnp.bfloat16`.
+  """
+
+  def _to_bfloat16_leaf(x: Any) -> Any:
+    if isinstance(x, jax.ShapeDtypeStruct) and jnp.issubdtype(
+        x.dtype, jnp.floating
+    ):
+      return jax.ShapeDtypeStruct(
+          x.shape,
+          jnp.bfloat16,
+          sharding=x.sharding,
+      )
+    if isinstance(x, jax.ShapeDtypeStruct):
+      return x
+    if hasattr(x, 'dtype') and jnp.issubdtype(x.dtype, jnp.floating):
+      return x.astype(jnp.bfloat16)
+    if isinstance(x, float):
+      return jnp.bfloat16(x)
+    return x
+
+  flattened_x, treedef = jax.tree_util.tree_flatten(x)
+  flattened_y = [
+      jax.tree_util.tree_map(_to_bfloat16_leaf, y) for y in flattened_x
+  ]
+  return jax.tree_util.tree_unflatten(treedef, flattened_y)
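A usage sketch for the new helper, assuming it is importable as orbax.export.utils; the pytree below is illustrative and mirrors the leaf types named in the docstring.

import jax
import jax.numpy as jnp

from orbax.export import utils  # assumed import path for export/orbax/export/utils.py

tree = {
    'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),  # float spec -> bfloat16 spec
    'step': jnp.zeros((), jnp.int32),                # non-float leaf is left untouched
    'lr': 1e-3,                                      # Python float -> jnp.bfloat16 scalar
}
converted = utils.to_bfloat16(tree)
assert converted['w'].dtype == jnp.bfloat16
assert converted['step'].dtype == jnp.int32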

model/orbax/experimental/model/cli/README.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 A command-line tool for inspecting Orbax models.
 
+
 ## Examples
 
 To inspect the model:

model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 30 additions & 3 deletions
@@ -115,7 +115,7 @@ def generate_xla_compile_options(
     native_serialization_platforms: Sequence[str] | None,
     xla_flags_per_platform: Mapping[str, Sequence[str] | None],
     jax_mesh: jax.sharding.Mesh | None = None,
-    strip_xla_flags: bool = False,
+    persist_xla_flags: bool = False,
 ) -> manifest_pb2.CompileOptionsProtoMap:
   """Sets the XLA compilation options.
 
@@ -127,7 +127,7 @@ def generate_xla_compile_options(
       which will be used to override the default XLA compilation flags.
     jax_mesh: The JAX mesh used for sharding. If None, the compile options will
       be set for a default single-replica.
-    strip_xla_flags: Whether to strip XLA flags from the compile options.
+    persist_xla_flags: Whether to persist XLA flags in the compile options.
 
   Returns:
     A `CompileOptionsProtoMap` containing the XLA compilation options per
@@ -140,6 +140,9 @@ def generate_xla_compile_options(
     ValueError: If a platform is provided for XLA flags which is not provided
       in the native serialization platforms.
     ValueError: If the supplied XLA flag overrides cannot be parsed.
+    ValueError: If `xla_flags` are provided but `persist_xla_flags` is False.
+      This ensures that the XLA flags are persisted in the compile options,
+      otherwise they would be lost, leading to unexpected behavior.
   """
   tpu_platform_name = manifest_pb2.Platform.Name(
       manifest_pb2.Platform.TPU
@@ -183,16 +186,40 @@ def generate_xla_compile_options(
     xla_flags = None
     if xla_flags_per_platform:
       xla_flags = xla_flags_per_platform.get(platform, None)
+    _validate_xla_flags_setting(xla_flags, persist_xla_flags)
     compile_environment = generate_tpu_compilation_env(xla_flags)
     compile_options_map.map[platform.lower()].CopyFrom(
         generate_compilation_options(compile_environment, jax_mesh)
     )
-  if strip_xla_flags:
+  if not persist_xla_flags:
    for compile_options in compile_options_map.map.values():
       compile_options.executable_build_options.comp_envs.Clear()
   return compile_options_map
 
 
+def _validate_xla_flags_setting(
+    xla_flags: Sequence[str] | None, persist_xla_flags: bool
+) -> None:
+  """Validates the XLA flags setting.
+
+  Ensures that if `xla_flags` are provided, `persist_xla_flags` is set to True.
+  If `xla_flags` are provided but not persisted, they would be lost, leading
+  to unexpected behavior.
+
+  Args:
+    xla_flags: A sequence of XLA flags provided for overriding. Can be None.
+    persist_xla_flags: A boolean indicating whether the XLA flags should be
+      persisted in the compile options.
+
+  Raises:
+    ValueError: If `xla_flags` is not None but `persist_xla_flags` is False.
+  """
+  if xla_flags is not None and not persist_xla_flags:
+    raise ValueError(
+        'persist_xla_flags must be True if xla_flags are provided.'
+    )
+
+
 def get_field_for_flag(flag_name: str) -> descriptor.FieldDescriptor:
   """Gets the protobuf field descriptor for a given flag name."""
   if flag_name not in _XLA_FLAG_TO_FIELD_MAP:
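A minimal sketch of the renamed parameter's contract (import path assumed from this file's location; it mirrors the no-flags test cases in the next file).

from orbax.experimental.model.core.python import compile_options_util

# Without flag overrides, persist_xla_flags may stay at its default of False.
options_map = compile_options_util.generate_xla_compile_options(
    native_serialization_platforms=['cpu', 'tpu'],
    xla_flags_per_platform=None,
    persist_xla_flags=False,
)

# Supplying per-platform xla_flags now requires persist_xla_flags=True;
# otherwise _validate_xla_flags_setting raises:
#   ValueError: persist_xla_flags must be True if xla_flags are provided.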

model/orbax/experimental/model/core/python/compile_options_util_test.py

Lines changed: 18 additions & 44 deletions
@@ -197,6 +197,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
           native_serialization_platforms=None,
           xla_flags_per_platform=None,
           expected_platforms=['tpu'],
+          persist_xla_flags=False,
       ),
       dict(
           testcase_name='no_native_serialization_platforms_with_xla_flags',
@@ -205,12 +206,14 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
               'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()]
           },
           expected_platforms=['tpu'],
+          persist_xla_flags=True,
       ),
       dict(
           testcase_name='with_native_serialization_platforms_no_xla_flags',
           native_serialization_platforms=['cpu', 'tpu', 'cuda'],
           xla_flags_per_platform=None,
           expected_platforms=['cpu', 'tpu', 'cuda'],
+          persist_xla_flags=False,
       ),
       dict(
           testcase_name='with_native_serialization_platforms_with_xla_flags',
@@ -219,25 +222,28 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
               'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()]
           },
           expected_platforms=['cpu', 'tpu', 'cuda'],
+          persist_xla_flags=True,
       ),
   )
   def test_generate_xla_compile_options_flags_and_platforms(
       self,
       native_serialization_platforms,
       xla_flags_per_platform,
       expected_platforms,
+      persist_xla_flags,
   ):
     compile_options_map = compile_options_util.generate_xla_compile_options(
         native_serialization_platforms=native_serialization_platforms,
         xla_flags_per_platform=xla_flags_per_platform,
+        persist_xla_flags=persist_xla_flags,
     )
     self.assertLen(compile_options_map.map, len(expected_platforms))
 
     for platform in expected_platforms:
       self.assertIn(platform, compile_options_map.map)
       compile_options = compile_options_map.map[platform]
 
-      if platform != 'tpu':
+      if platform != 'tpu' or not persist_xla_flags:
         self.assertEmpty(
             compile_options.executable_build_options.comp_envs.environments
         )
@@ -307,49 +313,17 @@ def test_generate_xla_compile_options_xla_flags_platform_not_in_native_serializa
          },
     )
 
-  @parameterized.named_parameters(
-      dict(testcase_name='strip_xla_flags_true', strip_xla_flags=True),
-      dict(testcase_name='strip_xla_flags_false', strip_xla_flags=False),
-  )
-  def test_generate_xla_compile_options_strip_xla_flags(self, strip_xla_flags):
-    xla_flags_per_platform = {
-        'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()]
-    }
-    compile_options_map = compile_options_util.generate_xla_compile_options(
-        native_serialization_platforms=['cpu', 'tpu', 'cuda'],
-        xla_flags_per_platform=xla_flags_per_platform,
-        strip_xla_flags=strip_xla_flags,
-    )
-    self.assertLen(compile_options_map.map, 3)
-    for platform in ['cpu', 'tpu', 'cuda']:
-      self.assertIn(platform, compile_options_map.map)
-      compile_options = compile_options_map.map[platform]
-
-      if strip_xla_flags or platform != 'tpu':
-        self.assertEmpty(
-            compile_options.executable_build_options.comp_envs.environments
-        )
-      else:
-        # For TPU platform when not stripping, it should have xla flags.
-        self.assertLen(
-            compile_options.executable_build_options.comp_envs.environments, 1
-        )
-        actual_env_proto = tpu_comp_env_pb2.TpuCompilationEnvironment()
-        compile_options.executable_build_options.comp_envs.environments[
-            0
-        ].Unpack(actual_env_proto)
-
-        expected_env_overrides = EXPECTED_ENV
-        expected_env_proto = tpu_comp_env_pb2.TpuCompilationEnvironment()
-        expected_env_proto.ParseFromString(
-            tpu_comp_env.create_default_tpu_comp_env()
-        )
-        expected_env_proto.MergeFrom(expected_env_overrides)
-
-        self.assertEqual(
-            text_format.MessageToString(actual_env_proto),
-            text_format.MessageToString(expected_env_proto),
-        )
+  def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
+    with self.assertRaisesWithLiteralMatch(
+        ValueError, 'persist_xla_flags must be True if xla_flags are provided.'
+    ):
+      compile_options_util.generate_xla_compile_options(
+          native_serialization_platforms=['tpu'],
+          xla_flags_per_platform={
+              'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()]
+          },
+          persist_xla_flags=False,
+      )
 
   @parameterized.named_parameters(
       dict(
