Allow metrics layers to have state. #978

Merged · 2 commits · Feb 11, 2025
4 changes: 3 additions & 1 deletion axlearn/common/causal_lm.py
@@ -33,7 +33,8 @@
from axlearn.common.layers import LayerNorm
from axlearn.common.logit_modifiers import LogitsToLogitsFn
from axlearn.common.loss import cross_entropy
from axlearn.common.metrics import BaseLossMetrics, WeightedScalar
from axlearn.common.loss_metrics import BaseLossMetrics
from axlearn.common.metrics import WeightedScalar
from axlearn.common.module import Module, NestedTensor, Tensor, child_context
from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
from axlearn.common.utils import (
@@ -48,6 +49,7 @@ def layer_norm_config(eps=1e-5):
return LayerNorm.default_config().set(eps=eps)


# TODO(markblee): Move these to `axlearn.common.loss_metrics` and update golden configs.
class CrossEntropyLossMetrics(BaseLossMetrics):
"""Computes cross entropy loss and related training summaries."""

3 changes: 2 additions & 1 deletion axlearn/common/causal_lm_test.py
@@ -25,7 +25,8 @@
from axlearn.common.config import config_for_function
from axlearn.common.learner import Learner
from axlearn.common.loss import cross_entropy
from axlearn.common.metrics import BaseLossMetrics, MetricAccumulator, WeightedScalar
from axlearn.common.loss_metrics import BaseLossMetrics
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import (
InvocationContext,
OutputCollection,
21 changes: 3 additions & 18 deletions axlearn/common/learner.py
@@ -6,6 +6,7 @@
- Applying updates on non-differentiable params such as batch norm stats;
- Maintaining Polyak averages of model params (if enabled).
"""

from __future__ import annotations

import dataclasses
@@ -47,7 +48,7 @@
Tensor,
flatten_items,
match_regex_rules,
prune_tree,
prune_empty,
register_per_param_settings,
tree_paths,
)
@@ -91,22 +92,6 @@ def should_apply_state_updates(update_type: UpdateType) -> bool:
return update_type in (UpdateType.STATE_UPDATES, UpdateType.ALL_UPDATES)


def _prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]:
"""Returns a shallow copy of the input tree with empty subtrees pruned.

If a tree would be made empty by removal of its subtrees, it will also be pruned.
This is a shallow copy because leaf nodes (non-dict values) are not deep-copied.

Args:
in_tree: the input tree to be pruned.

Returns:
The pruned copy of the input tree.
"""
# Note that falsey values or empty Tensors are not considered empty.
return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v)


class BaseLearner(LearnerModule):
"""The base class of a learner."""

@@ -309,7 +294,7 @@ def _compute_updated_params(
updated_model_params = optax.apply_updates(
jax.tree.map(lambda op: op.value, opt_params), parameter_updates
)
state_updates = _prune_empty(state_updates)
state_updates = prune_empty(state_updates)
apply_state_updates = jax.tree.map(
should_apply_state_updates,
self._update_types(state_updates),
31 changes: 1 addition & 30 deletions axlearn/common/learner_test.py
@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
"""Tests learner."""

import copy
import re
from numbers import Number
@@ -24,7 +25,6 @@
Learner,
UpdateType,
_apply_updates,
_prune_empty,
_split_gradients,
_value_and_grad,
should_update_with_optimizers,
@@ -169,35 +169,6 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
)

def test_prune_empty_state(self):
state = {
"state": {
"tensor": jnp.array(0),
"nested": {
"empty": {},
"not_empty": jnp.array([]),
},
},
"removed": {
"nested": {
"deep_nested": {},
},
"sibling": {
"deep_nested": {},
},
},
}
expected = {
"state": {
"tensor": jnp.array(0),
"nested": {
"not_empty": jnp.array([]),
},
},
}
actual = _prune_empty(state)
self.assertNestedAllClose(expected, actual)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
def test_learner(self, ema_decay: Optional[float], method: str):
learning_rate = config_for_function(schedule.stepwise).set(
32 changes: 32 additions & 0 deletions axlearn/common/loss_metrics.py
@@ -0,0 +1,32 @@
# Copyright © 2025 Apple Inc.

"""Layers for computing training time metrics."""

from axlearn.common.base_layer import BaseLayer
from axlearn.common.utils import Nested, Tensor


class BaseLossMetrics(BaseLayer):
"""A module for computing training time metrics.

See `causal_lm.Model` for an example usage.
"""

def forward(
self,
input_batch: Nested[Tensor],
*,
predict_outputs: Nested[Tensor],
module_outputs: Nested[Tensor],
) -> tuple[Tensor, Nested[Tensor]]:
"""Computes metrics from inputs and predictions.

Args:
input_batch: A mapping from input keys to Tensors.
predict_outputs: Model predictions for computing metrics.
module_outputs: Outputs from the model's invocation context.

Returns:
A tuple (loss, metrics).
"""
raise NotImplementedError(type(self))
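
Since `BaseLossMetrics` now extends `BaseLayer` rather than `Module`, metrics implementations are ordinary layers and can carry parameters or state like any other axlearn layer. Below is a minimal sketch of what a subclass might look like; the input keys (`"target_labels"`, `"logits"`) and the convention that negative labels mark padding are assumptions for illustration, not taken from this PR.

```python
# Hypothetical metrics layer, for illustration only. The input keys and the
# padding convention (negative labels are ignored) are assumptions.
import jax.numpy as jnp

from axlearn.common.loss_metrics import BaseLossMetrics
from axlearn.common.metrics import WeightedScalar
from axlearn.common.utils import Nested, Tensor


class AccuracyMetrics(BaseLossMetrics):
    """Reports token-level accuracy as a training summary."""

    def forward(
        self,
        input_batch: Nested[Tensor],
        *,
        predict_outputs: Nested[Tensor],
        module_outputs: Nested[Tensor],
    ) -> tuple[Tensor, Nested[Tensor]]:
        del module_outputs  # Unused by this metric.
        labels = input_batch["target_labels"]
        logits = predict_outputs["logits"]
        # Treat negative labels as padding (assumed convention).
        mask = (labels >= 0).astype(logits.dtype)
        correct = (jnp.argmax(logits, axis=-1) == labels).astype(logits.dtype) * mask
        num_targets = mask.sum()
        accuracy = correct.sum() / jnp.maximum(num_targets, 1)
        # This layer contributes no loss of its own; it only emits a summary.
        loss = jnp.zeros([], dtype=logits.dtype)
        return loss, {"accuracy": WeightedScalar(accuracy, num_targets)}
```

A concrete in-tree implementation is `CrossEntropyLossMetrics` in the `causal_lm.py` diff above.
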
29 changes: 1 addition & 28 deletions axlearn/common/metrics.py
@@ -9,9 +9,8 @@
from absl import logging

from axlearn.common.config import Configurable
from axlearn.common.module import Module
from axlearn.common.summary import Summary
from axlearn.common.utils import Nested, NestedTensor, Tensor
from axlearn.common.utils import NestedTensor, Tensor


class WeightedScalarValue(Summary):
@@ -51,32 +50,6 @@ def accumulate(self, other: Summary) -> Summary:
return self + other


class BaseLossMetrics(Module):
"""A module for computing training time metrics.

See `causal_lm.Model` for an example usage.
"""

def forward(
self,
input_batch: Nested[Tensor],
*,
predict_outputs: Nested[Tensor],
module_outputs: Nested[Tensor],
) -> tuple[Tensor, Nested[Tensor]]:
"""Computes metrics from inputs and predictions.

Args:
input_batch: A mapping from input keys to Tensors.
predict_outputs: Model predictions for computing metrics.
module_outputs: Outputs from the model's invocation context.

Returns:
A tuple (loss, metrics).
"""
raise NotImplementedError(type(self))


class MetricAccumulator(Configurable):
"""A MetricAccumulator is used during evaluation to accumulate metrics across batches."""

6 changes: 4 additions & 2 deletions axlearn/common/test_utils.py
@@ -56,6 +56,7 @@
complete_partition_spec_tree,
flatten_items,
pop_data_dir,
prune_empty,
prune_tree,
push_data_dir,
set_data_dir,
@@ -194,8 +195,9 @@ def _compute_layer_outputs(
)
# Optionally, test that trees also have the same structure.
if require_same_tree_structure:
ref_structure = jax.tree_util.tree_structure(params_from_ref)
test_structure = jax.tree_util.tree_structure(layer_params)
# Prune empty subtrees so we don't require empty dicts for layers with no params.
ref_structure = jax.tree_util.tree_structure(prune_empty(params_from_ref))
test_structure = jax.tree_util.tree_structure(prune_empty(layer_params))
self.assertEqual(
ref_structure, test_structure, msg=f"\nRef: {ref_structure}\nTest: {test_structure}"
)
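
The rationale for pruning before the structure check, as a standalone sketch (the parameter trees here are made up for illustration): a layer with no parameters may show up as an empty dict in one tree and be entirely absent from the other, which would otherwise make the treedef comparison fail. `prune_empty` is the helper added to `axlearn/common/utils.py` in this PR.

```python
# Standalone illustration of the structure check above; the trees are invented.
import jax
import jax.numpy as jnp

from axlearn.common.utils import prune_empty

ref_params = {"linear": {"weight": jnp.ones((2, 2))}, "dropout": {}}  # Keeps an empty subtree.
test_params = {"linear": {"weight": jnp.ones((2, 2))}}  # Omits the param-less layer entirely.

# Without pruning, the treedefs differ because of the empty "dropout" subtree.
assert jax.tree_util.tree_structure(ref_params) != jax.tree_util.tree_structure(test_params)

# After pruning empty subtrees, both trees have the same structure.
pruned_ref = jax.tree_util.tree_structure(prune_empty(ref_params))
pruned_test = jax.tree_util.tree_structure(prune_empty(test_params))
assert pruned_ref == pruned_test
```
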
16 changes: 16 additions & 0 deletions axlearn/common/utils.py
@@ -1696,3 +1696,19 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
f"Input is expected to contain '{path}'; "
f"instead, it contains: '{jax.tree_structure(x)}'."
) from e


def prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]:
"""Returns a shallow copy of the input tree with empty subtrees pruned.

If a tree would be made empty by removal of its subtrees, it will also be pruned.
This is a shallow copy because leaf nodes (non-dict values) are not deep-copied.

Args:
in_tree: the input tree to be pruned.

Returns:
The pruned copy of the input tree.
"""
# Note that falsey values or empty Tensors are not considered empty.
return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v)
30 changes: 30 additions & 0 deletions axlearn/common/utils_test.py
@@ -71,6 +71,7 @@
infer_mesh_shape,
input_partition_spec,
match_regex_rules,
prune_empty,
prune_tree,
pytree_children,
replicate_to_local_data,
@@ -780,6 +781,35 @@ def test_sequence_mask(self, lengths, dtype, expected):
expected = jnp.array(expected).astype(dtype if dtype else jnp.int32)
self.assertNestedAllClose(mask, expected)

def test_prune_empty_state(self):
state = {
"state": {
"tensor": jnp.array(0),
"nested": {
"empty": {},
"not_empty": jnp.array([]),
},
},
"removed": {
"nested": {
"deep_nested": {},
},
"sibling": {
"deep_nested": {},
},
},
}
expected = {
"state": {
"tensor": jnp.array(0),
"nested": {
"not_empty": jnp.array([]),
},
},
}
actual = prune_empty(state)
self.assertNestedAllClose(expected, actual)


class SimilarNamesTest(TestCase):
@parameterized.parameters(