Workaround module outputs being dropped. #951

Merged 1 commit on Jan 26, 2025
12 changes: 12 additions & 0 deletions axlearn/common/causal_lm.py
@@ -6,6 +6,7 @@
import re
from typing import Callable, Optional, Union

import jax
from absl import logging
from jax import numpy as jnp
from jax._src.mesh import thread_resources
@@ -184,15 +185,26 @@ def forward(
            # Collect aux_loss from all leaves in the invocation hierarchy, not just current ctx.
            ctx = self.get_invocation_context()
            while ctx.parent:
                # TODO(markblee): Fix learner dropping module outputs in forward.
                if isinstance(ctx.module, BaseModel):
                    break
                ctx = ctx.parent
            module_outputs = ctx.get_module_outputs()

            logging.info("Context: %s Module outputs: %s", ctx, jax.tree_structure(module_outputs))
            accumulation = []
            for k, _ in flatten_items(module_outputs):
                if re.fullmatch(regex, k):
                    logging.info("Aux loss found at %s", k)
                else:
                    logging.info("Aux loss not found at %s", k)
            accumulation = list(
                v.mean() for k, v in flatten_items(module_outputs) if re.fullmatch(regex, k)
            )
            if accumulation:
                aux_loss = sum(accumulation) / len(accumulation)
            else:
                logging.warning("Aux loss not found: %s", cfg.aux_loss_regex)
                aux_loss = 0.0

            self.add_summary("aux_loss", WeightedScalar(aux_loss, num_targets))
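The workaround above stops the upward context walk at the owning `BaseModel` rather than continuing to the root (per the TODO, the learner drops module outputs recorded in forward, so outputs attached above the model level would otherwise be lost), then averages the mean of every flattened output whose path fully matches `aux_loss_regex`, falling back to 0.0 with a warning when nothing matches. A minimal standalone sketch of that matching-and-averaging step, using a plain dict in place of the real module outputs (the helper name and sample paths below are hypothetical):

```python
import re

import jax.numpy as jnp


def average_matching_aux_losses(flat_outputs: dict, regex: str):
    """Averages the per-tensor mean of every flattened output whose path matches `regex`.

    Mirrors the aggregation above: non-matching paths are ignored, and an empty
    match set falls back to an aux loss of 0.0.
    """
    matched = [v.mean() for k, v in flat_outputs.items() if re.fullmatch(regex, k)]
    if not matched:
        return jnp.asarray(0.0)
    return sum(matched) / len(matched)


# Hypothetical flattened outputs: two layers each report an aux_loss of 1.0.
flat_outputs = {
    "decoder/transformer/layer0/feed_forward/aux_loss": jnp.asarray(1.0),
    "decoder/transformer/layer1/feed_forward/aux_loss": jnp.asarray(1.0),
    "decoder/transformer/layer0/feed_forward/linear1/outputs": jnp.ones((2, 4)),
}
assert average_matching_aux_losses(flat_outputs, ".*/aux_loss") == 1.0
```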
133 changes: 118 additions & 15 deletions axlearn/common/causal_lm_test.py
@@ -3,6 +3,7 @@
"""Tests autoregressive models."""

from functools import partial
from typing import cast

import jax
import jax.random
@@ -15,23 +16,31 @@

from axlearn.common import causal_lm, utils
from axlearn.common.attention import (
    BaseStackedTransformerLayer,
    CausalAttentionLogitBiasLayer,
    RepeatedTransformerLayer,
    StackedTransformerLayer,
    TransformerFeedForwardLayer,
)
from axlearn.common.config import config_for_function
from axlearn.common.learner import Learner
from axlearn.common.loss import cross_entropy
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import (
    InvocationContext,
    OutputCollection,
    child_context,
    functional,
    new_output_collection,
    set_current_context,
)
from axlearn.common.optimizer_base import OptParam
from axlearn.common.optimizers import sgd_optimizer
from axlearn.common.param_converter import as_torch_tensor
from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.torch_utils import parameters_from_torch_layer
from axlearn.common.update_transformation import ForwardBackwardOutputs, ForwardOutputs
from axlearn.common.utils import Tensor


@@ -431,21 +440,19 @@ def forward(self, inputs: Tensor) -> Tensor:


class ModelAuxLossTest(TestCase):
    @parameterized.product(
        aux_loss_regex=(None, ".*/aux_loss", ".*/apple"),
        stack_cfg=(
            RepeatedTransformerLayer.default_config(),
            StackedTransformerLayer.default_config(),
        ),
        use_aux_layer=(False, True),
    )
    def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
        batch_size, seq_len, vocab_size = 3, 10, 10
        hidden_dim = 8
        num_layers = 6
    def _model_config(
        self,
        *,
        stack_cfg: BaseStackedTransformerLayer.Config,
        hidden_dim: int,
        vocab_size: int,
        seq_len: int,
        aux_loss_regex: str,
        use_aux_layer: bool,
    ) -> causal_lm.Model.Config:
        decoder_cfg = causal_lm.gpt_decoder_config(
            stack_cfg=stack_cfg,
            num_layers=num_layers,
            num_layers=6,
            hidden_dim=hidden_dim,
            num_heads=4,
            vocab_size=vocab_size,
@@ -457,11 +464,31 @@ def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
            decoder_cfg.transformer.layer.feed_forward = (
                DummyFeedForwardWithAuxLoss.default_config().set(hidden_dim=4 * hidden_dim)
            )
        model_cfg: causal_lm.Model.Config = causal_lm.Model.default_config().set(
        return causal_lm.Model.default_config().set(
            decoder=decoder_cfg,
            name="metrics_test",
            metrics=causal_lm.metrics_config(aux_loss_regex=aux_loss_regex),
        )

    @parameterized.product(
        aux_loss_regex=(None, ".*/aux_loss", ".*/apple"),
        stack_cfg=(
            RepeatedTransformerLayer.default_config(),
            StackedTransformerLayer.default_config(),
        ),
        use_aux_layer=(False, True),
    )
    def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
        batch_size, seq_len, vocab_size = 3, 10, 10
        hidden_dim = 8
        model_cfg = self._model_config(
            stack_cfg=stack_cfg,
            hidden_dim=hidden_dim,
            vocab_size=vocab_size,
            seq_len=seq_len,
            aux_loss_regex=aux_loss_regex,
            use_aux_layer=use_aux_layer,
        )
        model = model_cfg.instantiate(parent=None)
        prng_key, init_key = jax.random.split(jax.random.PRNGKey(123))
        model_params = model.initialize_parameters_recursively(init_key)
@@ -494,6 +521,82 @@ def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
            self.assertNotIn("aux_loss", aux)
            self.assertEqual(aux["metrics"]["cross_entropy"], loss)

    @parameterized.product(
        stack_cfg=(
            RepeatedTransformerLayer.default_config(),
            StackedTransformerLayer.default_config(),
        ),
    )
    def test_aux_loss_learner(self, stack_cfg):
        batch_size, seq_len, vocab_size = 3, 10, 10
        hidden_dim = 8
        model_cfg = self._model_config(
            stack_cfg=stack_cfg,
            hidden_dim=hidden_dim,
            vocab_size=vocab_size,
            seq_len=seq_len,
            aux_loss_regex=".*/aux_loss",
            use_aux_layer=True,
        )
        model: causal_lm.Model = model_cfg.set(name="model").instantiate(parent=None)
        learner_cfg = Learner.default_config().set(
            optimizer=config_for_function(sgd_optimizer).set(
                learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
            )
        )
        learner = learner_cfg.set(name="learner").instantiate(parent=None)
        init_key, forward_key = jax.random.split(jax.random.PRNGKey(123), num=2)

        model_cfg = causal_lm.Model.default_config()
        params = model.initialize_parameters_recursively(init_key)
        opt_params = jax.tree.map(
            lambda v: OptParam(value=v, factorization_spec=None, weight_decay_scale=None), params
        )
        state = learner.init(model_params=opt_params)

        input_ids = jax.random.randint(
            jax.random.PRNGKey(123), shape=[batch_size, seq_len], minval=0, maxval=vocab_size
        )
        target_labels = jax.random.randint(
            jax.random.PRNGKey(123), shape=[batch_size, seq_len], minval=-1, maxval=vocab_size
        )
        input_batch = dict(input_ids=input_ids, target_labels=target_labels)

        def loss_fn(model_params, inputs):
            model_output_collection = new_output_collection()
            with child_context(
                "model",
                module=model,
                state=model_params,
                prng_key=inputs["forward_key"],
                output_collection=model_output_collection,
            ):
                loss, aux = model(input_batch=inputs["input_batch"])
            return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

        inputs = dict(
            fn=loss_fn,
            inputs=dict(forward_key=forward_key, input_batch=input_batch),
            opt_params=opt_params,
        )
        outputs, _ = functional(
            learner,
            method="forward_and_backward",
            is_training=True,
            prng_key=forward_key,
            state=state,
            inputs=inputs,
        )
        outputs = cast(ForwardBackwardOutputs, outputs)
        output_collection: OutputCollection = outputs.forward_outputs.output_collection
        summaries: dict[str, WeightedScalar] = output_collection.summaries
        self.assertIn("aux_loss", summaries)
        self.assertEqual(summaries["aux_loss"].mean, 1.0)
        self.assertEqual(
            summaries["cross_entropy_loss"].mean + summaries["aux_loss"].mean,
            outputs.forward_outputs.loss,
        )


if __name__ == "__main__":
    with utils.numeric_checks(True):
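As a reading aid for the test matrix: an `aux_loss` entry is only expected when `aux_loss_regex` fully matches a flattened module-output path and the dummy aux layer is enabled, and the new `test_aux_loss_learner` additionally checks that this summary survives the learner's `forward_and_backward` path and that the total loss equals `cross_entropy_loss + aux_loss`. A small standalone check of the matching behavior (the sample path below is illustrative, not taken from the test):

```python
import re

# Illustrative flattened output path for the dummy feed-forward layer's aux loss.
sample_path = "decoder/transformer/layer0/feed_forward/aux_loss"

for regex in (".*/aux_loss", ".*/apple"):
    hit = re.fullmatch(regex, sample_path) is not None
    print(f"{regex!r}: {'matches' if hit else 'no match'}")
# ".*/aux_loss" matches, so an aux_loss summary is reported;
# ".*/apple" (like aux_loss_regex=None) produces no aux_loss summary.
```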