External KV input for _update_layer_kwargs #1025

Merged · 1 commit · Feb 27, 2025
axlearn/common/attention.py (10 changes: 9 additions & 1 deletion)
@@ -3603,6 +3603,7 @@ def _forward_for_mode(
"""
all_layer_outputs = []
all_layer_states = []
external_self_attention_kv_state = layer_kwargs.get("self_attention_kv_state")

# True iff we are initializing an empty cache (i.e., not prefilling).
cache_init = mode == ForwardMode.INIT_STATES and cached_states is None
@@ -3612,7 +3613,11 @@
if self._update_data is not None:
data = self._update_data(data, all_layer_outputs)
# TODO(markblee): Consider folding into _update_data.
self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs)
self._update_layer_kwargs(
layer_kwargs,
all_layer_outputs=all_layer_outputs,
external_self_attention_kv_state=external_self_attention_kv_state,
)

if mode == ForwardMode.FORWARD:
layer_states, layer_outputs = None, layer(data, **layer_kwargs)
@@ -3668,6 +3673,7 @@ def _update_layer_kwargs(
layer_kwargs: dict[str, Any],
*,
all_layer_outputs: list[BaseTransformerLayer.Output],
external_self_attention_kv_state: Optional[KVState] = None,
):
"""Updates `layer_kwargs` using other args.

@@ -3678,6 +3684,8 @@
layer_kwargs: a dictionary of arguments that can be used by individual layers.
all_layer_outputs: a list of BaseTransformerLayer.Output that is appended with
the output of each constituent layer in the stack.
external_self_attention_kv_state: A KVState that this function processes
to populate (if needed) the self_attention_kv_state within `layer_kwargs`.
"""
pass # Do nothing by default.

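Below is a minimal sketch, not part of this change, of how a subclass could consume the new argument; `ExternalKVStackedTransformerLayer` is a hypothetical name used only for illustration. It mirrors the override pattern in the test file: the first layer attends to the caller-supplied KV state, and each later layer reuses the KV state emitted by the layer before it.

# Hypothetical sketch (not part of this PR): a stack subclass that consumes the
# new `external_self_attention_kv_state` argument. Layer 0 attends to the
# caller-supplied KV state; later layers reuse the previous layer's KV state.
from typing import Any, Optional

from axlearn.common.attention import (
    BaseTransformerLayer,
    KVState,
    StackedTransformerLayer,
)


class ExternalKVStackedTransformerLayer(StackedTransformerLayer):
    """Hypothetical stack that shares an externally computed KV state with layer 0."""

    def _update_layer_kwargs(
        self,
        layer_kwargs: dict[str, Any],
        *,
        all_layer_outputs: list[BaseTransformerLayer.Output],
        external_self_attention_kv_state: Optional[KVState] = None,
    ):
        layer_index = len(all_layer_outputs)
        if layer_index == 0:
            if external_self_attention_kv_state is not None:
                # First layer: attend to the KV state supplied by the caller.
                layer_kwargs["self_attention_kv_state"] = external_self_attention_kv_state
        else:
            # Later layers: reuse the KV state produced by the preceding layer.
            layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state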
axlearn/common/attention_test.py (48 changes: 48 additions & 0 deletions)
@@ -4183,7 +4183,10 @@ def _update_layer_kwargs(
layer_kwargs: dict[str, Any],
*,
all_layer_outputs: list[BaseTransformerLayer.Output],
external_self_attention_kv_state: Optional[KVState] = None,
):
del external_self_attention_kv_state

layer_index = len(all_layer_outputs)
if layer_index == 1:
layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state
@@ -4586,6 +4589,51 @@ def test_skip_connection(self):
0.0,
)

def test_passthrough_update_layer_kwargs(self):
num_heads = 2
input_dim = 4
hidden_dim = 8
num_layers = 3

cfg = StackedTransformerLayer.default_config().set(name="test")
cfg.input_dim = input_dim
cfg.num_layers = num_layers

transformer_cfg = TransformerLayer.default_config()
transformer_cfg.self_attention.attention.num_heads = num_heads
transformer_cfg.feed_forward.hidden_dim = hidden_dim
cfg.layer = transformer_cfg

layer: StackedTransformerLayer = cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))

input_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
expected_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
k_proj = jnp.zeros([3, 3])
v_proj = jnp.ones([3, 3])
input_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
expected_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
F(
layer,
prng_key=jax.random.PRNGKey(0),
state=state,
inputs=dict(
layer_kwargs={},
all_layer_outputs=[],
external_self_attention_kv_state=input_self_attention_kv_state,
),
method="_update_layer_kwargs",
is_training=True,
)
self.assertNestedAllClose(
input_all_layer_outputs,
expected_all_layer_outputs,
)
self.assertNestedAllClose(
input_self_attention_kv_state,
expected_self_attention_kv_state,
)

def test_update_layer_kwargs(self):
batch_size = 2
seq_len = 6
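To round this out, a hedged sketch of exercising a non-trivial override through the same `F(..., method="_update_layer_kwargs")` pattern the new passthrough test uses. It assumes the hypothetical `ExternalKVStackedTransformerLayer` from the sketch above is in scope; with an empty `all_layer_outputs` (layer index 0), the hook writes the external KV state into `layer_kwargs`.

# Sketch only, assuming ExternalKVStackedTransformerLayer from the sketch above.
import jax
import jax.numpy as jnp

from axlearn.common.attention import KVState, TransformerLayer
from axlearn.common.module import functional as F

cfg = ExternalKVStackedTransformerLayer.default_config().set(
    name="external_kv_stack", input_dim=4, num_layers=3
)
transformer_cfg = TransformerLayer.default_config()
transformer_cfg.self_attention.attention.num_heads = 2
transformer_cfg.feed_forward.hidden_dim = 8
cfg.layer = transformer_cfg

layer = cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))

# Shapes do not matter here: the hook only stores the KVState object.
external_kv = KVState(k_proj=jnp.zeros([3, 3]), v_proj=jnp.ones([3, 3]))
F(
    layer,
    prng_key=jax.random.PRNGKey(0),
    state=state,
    inputs=dict(
        layer_kwargs={},
        all_layer_outputs=[],  # Layer index 0, so the hook uses the external KV state.
        external_self_attention_kv_state=external_kv,
    ),
    method="_update_layer_kwargs",
    is_training=True,
)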