Commit ed8f382

Author: Mark Lee
Allow metrics layers to have state. (apple#978)
* Allow metrics layers to have state.
* Move BaseLossMetrics to a new file.
1 parent b130416 commit ed8f382

File tree

9 files changed (+92 −80 lines)


axlearn/common/causal_lm.py

Lines changed: 3 additions & 1 deletion
@@ -33,7 +33,8 @@
 from axlearn.common.layers import LayerNorm
 from axlearn.common.logit_modifiers import LogitsToLogitsFn
 from axlearn.common.loss import cross_entropy
-from axlearn.common.metrics import BaseLossMetrics, WeightedScalar
+from axlearn.common.loss_metrics import BaseLossMetrics
+from axlearn.common.metrics import WeightedScalar
 from axlearn.common.module import Module, NestedTensor, Tensor, child_context
 from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
 from axlearn.common.utils import (
@@ -48,6 +49,7 @@ def layer_norm_config(eps=1e-5):
     return LayerNorm.default_config().set(eps=eps)
 
 
+# TODO(markblee): Move these to `axlearn.common.loss_metrics` and update golden configs.
 class CrossEntropyLossMetrics(BaseLossMetrics):
     """Computes cross entropy loss and related training summaries."""

axlearn/common/causal_lm_test.py

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@
 from axlearn.common.config import config_for_function
 from axlearn.common.learner import Learner
 from axlearn.common.loss import cross_entropy
-from axlearn.common.metrics import BaseLossMetrics, MetricAccumulator, WeightedScalar
+from axlearn.common.loss_metrics import BaseLossMetrics
+from axlearn.common.metrics import MetricAccumulator, WeightedScalar
 from axlearn.common.module import (
     InvocationContext,
     OutputCollection,

axlearn/common/learner.py

Lines changed: 3 additions & 18 deletions
@@ -6,6 +6,7 @@
 - Applying updates on non-differentiable params such as batch norm stats;
 - Maintaining Polyak averages of model params (if enabled).
 """
+
 from __future__ import annotations
 
 import dataclasses
@@ -47,7 +48,7 @@
     Tensor,
     flatten_items,
     match_regex_rules,
-    prune_tree,
+    prune_empty,
     register_per_param_settings,
     tree_paths,
 )
@@ -91,22 +92,6 @@ def should_apply_state_updates(update_type: UpdateType) -> bool:
     return update_type in (UpdateType.STATE_UPDATES, UpdateType.ALL_UPDATES)
 
 
-def _prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]:
-    """Returns a shallow copy of the input tree with empty subtrees pruned.
-
-    If a tree would be made empty by removal of its subtrees, it will also be pruned.
-    This is a shallow copy because leaf nodes (non-dict values) are not deep-copied.
-
-    Args:
-        in_tree: the input tree to be pruned.
-
-    Returns:
-        The pruned copy of the input tree.
-    """
-    # Note that falsey values or empty Tensors are not considered empty.
-    return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v)
-
-
 class BaseLearner(LearnerModule):
     """The base class of a learner."""
 
@@ -309,7 +294,7 @@ def _compute_updated_params(
         updated_model_params = optax.apply_updates(
             jax.tree.map(lambda op: op.value, opt_params), parameter_updates
         )
-        state_updates = _prune_empty(state_updates)
+        state_updates = prune_empty(state_updates)
         apply_state_updates = jax.tree.map(
             should_apply_state_updates,
             self._update_types(state_updates),

axlearn/common/learner_test.py

Lines changed: 1 addition & 30 deletions
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 """Tests learner."""
+
 import copy
 import re
 from numbers import Number
@@ -24,7 +25,6 @@
     Learner,
     UpdateType,
     _apply_updates,
-    _prune_empty,
     _split_gradients,
     _value_and_grad,
     should_update_with_optimizers,
@@ -169,35 +169,6 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
             jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
         )
 
-    def test_prune_empty_state(self):
-        state = {
-            "state": {
-                "tensor": jnp.array(0),
-                "nested": {
-                    "empty": {},
-                    "not_empty": jnp.array([]),
-                },
-            },
-            "removed": {
-                "nested": {
-                    "deep_nested": {},
-                },
-                "sibling": {
-                    "deep_nested": {},
-                },
-            },
-        }
-        expected = {
-            "state": {
-                "tensor": jnp.array(0),
-                "nested": {
-                    "not_empty": jnp.array([]),
-                },
-            },
-        }
-        actual = _prune_empty(state)
-        self.assertNestedAllClose(expected, actual)
-
     @parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
     def test_learner(self, ema_decay: Optional[float], method: str):
         learning_rate = config_for_function(schedule.stepwise).set(

axlearn/common/loss_metrics.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright © 2025 Apple Inc.
+
+"""Layers for computing training time metrics."""
+
+from axlearn.common.base_layer import BaseLayer
+from axlearn.common.utils import Nested, Tensor
+
+
+class BaseLossMetrics(BaseLayer):
+    """A module for computing training time metrics.
+
+    See `causal_lm.Model` for an example usage.
+    """
+
+    def forward(
+        self,
+        input_batch: Nested[Tensor],
+        *,
+        predict_outputs: Nested[Tensor],
+        module_outputs: Nested[Tensor],
+    ) -> tuple[Tensor, Nested[Tensor]]:
+        """Computes metrics from inputs and predictions.
+
+        Args:
+            input_batch: A mapping from input keys to Tensors.
+            predict_outputs: Model predictions for computing metrics.
+            module_outputs: Outputs from the model's invocation context.
+
+        Returns:
+            A tuple (loss, metrics).
+        """
+        raise NotImplementedError(type(self))
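
Since `BaseLossMetrics` now extends `BaseLayer` rather than `Module`, subclasses can hold parameters and state like any other layer while keeping the same `forward` contract. A minimal sketch of a hypothetical subclass is below; the class name and the "live_targets" input key are illustrative, not part of this commit:

    import jax.numpy as jnp

    from axlearn.common.loss_metrics import BaseLossMetrics
    from axlearn.common.metrics import WeightedScalar
    from axlearn.common.utils import Nested, Tensor


    class TokenCountMetrics(BaseLossMetrics):
        """Reports the number of live target tokens; contributes no loss of its own."""

        def forward(
            self,
            input_batch: Nested[Tensor],
            *,
            predict_outputs: Nested[Tensor],
            module_outputs: Nested[Tensor],
        ) -> tuple[Tensor, Nested[Tensor]]:
            del predict_outputs, module_outputs  # Unused by this metric.
            # Count the non-padding target positions in the batch.
            num_targets = input_batch["live_targets"].sum()
            # Zero loss: this layer only reports metrics.
            return jnp.zeros([]), {"num_targets": WeightedScalar(num_targets, 1)}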

axlearn/common/metrics.py

Lines changed: 1 addition & 28 deletions
@@ -9,9 +9,8 @@
 from absl import logging
 
 from axlearn.common.config import Configurable
-from axlearn.common.module import Module
 from axlearn.common.summary import Summary
-from axlearn.common.utils import Nested, NestedTensor, Tensor
+from axlearn.common.utils import NestedTensor, Tensor
 
 
 class WeightedScalarValue(Summary):
@@ -51,32 +50,6 @@ def accumulate(self, other: Summary) -> Summary:
         return self + other
 
 
-class BaseLossMetrics(Module):
-    """A module for computing training time metrics.
-
-    See `causal_lm.Model` for an example usage.
-    """
-
-    def forward(
-        self,
-        input_batch: Nested[Tensor],
-        *,
-        predict_outputs: Nested[Tensor],
-        module_outputs: Nested[Tensor],
-    ) -> tuple[Tensor, Nested[Tensor]]:
-        """Computes metrics from inputs and predictions.
-
-        Args:
-            input_batch: A mapping from input keys to Tensors.
-            predict_outputs: Model predictions for computing metrics.
-            module_outputs: Outputs from the model's invocation context.
-
-        Returns:
-            A tuple (loss, metrics).
-        """
-        raise NotImplementedError(type(self))
-
-
 class MetricAccumulator(Configurable):
     """A MetricAccumulator is used during evaluation to accumulate metrics across batches."""

axlearn/common/test_utils.py

Lines changed: 4 additions & 2 deletions
@@ -56,6 +56,7 @@
     complete_partition_spec_tree,
     flatten_items,
     pop_data_dir,
+    prune_empty,
     prune_tree,
     push_data_dir,
     set_data_dir,
@@ -194,8 +195,9 @@ def _compute_layer_outputs(
         )
         # Optionally, test that trees also have the same structure.
         if require_same_tree_structure:
-            ref_structure = jax.tree_util.tree_structure(params_from_ref)
-            test_structure = jax.tree_util.tree_structure(layer_params)
+            # Prune empty subtrees so we don't require empty dicts for layers with no params.
+            ref_structure = jax.tree_util.tree_structure(prune_empty(params_from_ref))
+            test_structure = jax.tree_util.tree_structure(prune_empty(layer_params))
         self.assertEqual(
             ref_structure, test_structure, msg=f"\nRef: {ref_structure}\nTest: {test_structure}"
         )
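
The pruning in `_compute_layer_outputs` means a reference tree carrying an explicit empty dict for a parameter-free layer now compares structurally equal to a test tree that omits the entry altogether. A toy sketch of the idea (the "linear" and "dropout" keys are made up for illustration):

    import jax

    from axlearn.common.utils import prune_empty

    ref = {"linear": {"weight": 1.0}, "dropout": {}}  # Layer with no params.
    test = {"linear": {"weight": 1.0}}

    # The raw structures differ, but match once empty subtrees are pruned.
    assert jax.tree_util.tree_structure(ref) != jax.tree_util.tree_structure(test)
    assert jax.tree_util.tree_structure(prune_empty(ref)) == jax.tree_util.tree_structure(
        prune_empty(test)
    )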

axlearn/common/utils.py

Lines changed: 16 additions & 0 deletions
@@ -1696,3 +1696,19 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
                 f"Input is expected to contain '{path}'; "
                 f"instead, it contains: '{jax.tree_structure(x)}'."
             ) from e
+
+
+def prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]:
+    """Returns a shallow copy of the input tree with empty subtrees pruned.
+
+    If a tree would be made empty by removal of its subtrees, it will also be pruned.
+    This is a shallow copy because leaf nodes (non-dict values) are not deep-copied.
+
+    Args:
+        in_tree: the input tree to be pruned.
+
+    Returns:
+        The pruned copy of the input tree.
+    """
+    # Note that falsey values or empty Tensors are not considered empty.
+    return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v)
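
Two details of `prune_empty` worth noting from its docstring: pruning cascades upward (a dict left empty once its children are pruned is itself pruned), and falsey leaves such as empty Tensors are kept. A quick sketch:

    import jax.numpy as jnp

    from axlearn.common.utils import prune_empty

    tree = {"a": {"b": {}}, "c": jnp.array([])}
    # {"b": {}} is pruned, which empties "a", so "a" is pruned as well;
    # the empty Tensor under "c" is a leaf, not an empty dict, so it survives.
    assert set(prune_empty(tree)) == {"c"}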

axlearn/common/utils_test.py

Lines changed: 30 additions & 0 deletions
@@ -71,6 +71,7 @@
     infer_mesh_shape,
     input_partition_spec,
     match_regex_rules,
+    prune_empty,
     prune_tree,
     pytree_children,
     replicate_to_local_data,
@@ -780,6 +781,35 @@ def test_sequence_mask(self, lengths, dtype, expected):
         expected = jnp.array(expected).astype(dtype if dtype else jnp.int32)
         self.assertNestedAllClose(mask, expected)
 
+    def test_prune_empty_state(self):
+        state = {
+            "state": {
+                "tensor": jnp.array(0),
+                "nested": {
+                    "empty": {},
+                    "not_empty": jnp.array([]),
+                },
+            },
+            "removed": {
+                "nested": {
+                    "deep_nested": {},
+                },
+                "sibling": {
+                    "deep_nested": {},
+                },
+            },
+        }
+        expected = {
+            "state": {
+                "tensor": jnp.array(0),
+                "nested": {
+                    "not_empty": jnp.array([]),
+                },
+            },
+        }
+        actual = prune_empty(state)
+        self.assertNestedAllClose(expected, actual)
+
 
 class SimilarNamesTest(TestCase):
     @parameterized.parameters(
