
Commit f8d2c66

External KV input for _update_layer_kwargs (apple#1025)
1 parent a3bf5e2 commit f8d2c66

2 files changed: +57 -1 lines changed

axlearn/common/attention.py

Lines changed: 9 additions & 1 deletion
@@ -3603,6 +3603,7 @@ def _forward_for_mode(
         """
         all_layer_outputs = []
         all_layer_states = []
+        external_self_attention_kv_state = layer_kwargs.get("self_attention_kv_state")

         # True iff we are initializing an empty cache (i.e., not prefilling).
         cache_init = mode == ForwardMode.INIT_STATES and cached_states is None
@@ -3612,7 +3613,11 @@ def _forward_for_mode(
             if self._update_data is not None:
                 data = self._update_data(data, all_layer_outputs)
             # TODO(markblee): Consider folding into _update_data.
-            self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs)
+            self._update_layer_kwargs(
+                layer_kwargs,
+                all_layer_outputs=all_layer_outputs,
+                external_self_attention_kv_state=external_self_attention_kv_state,
+            )

             if mode == ForwardMode.FORWARD:
                 layer_states, layer_outputs = None, layer(data, **layer_kwargs)
@@ -3668,6 +3673,7 @@ def _update_layer_kwargs(
         layer_kwargs: dict[str, Any],
         *,
         all_layer_outputs: list[BaseTransformerLayer.Output],
+        external_self_attention_kv_state: Optional[KVState] = None,
     ):
         """Updates `layer_kwargs` using other args.

@@ -3678,6 +3684,8 @@ def _update_layer_kwargs(
             layer_kwargs: a dictionary of arguments that can be used by individual layers.
             all_layer_outputs: a list of BaseTransformerLayer.Output that is appended with
                 the output of each constituent layer in the stack.
+            external_self_attention_kv_state: A KVState that this function processes
+                to populate (if needed) the self_attention_kv_state within `layer_kwargs`.
         """
         pass  # Do nothing by default.

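For context, a minimal sketch (not part of this commit) of how a subclass could use the new hook. The class name ExternalKVStackedTransformerLayer is hypothetical; StackedTransformerLayer, BaseTransformerLayer, and KVState are the identifiers that appear in the diff and test on this page.

from typing import Any, Optional

from axlearn.common.attention import BaseTransformerLayer, KVState, StackedTransformerLayer


class ExternalKVStackedTransformerLayer(StackedTransformerLayer):
    """Hypothetical stack that shares an externally supplied KV state with every layer."""

    def _update_layer_kwargs(
        self,
        layer_kwargs: dict[str, Any],
        *,
        all_layer_outputs: list[BaseTransformerLayer.Output],
        external_self_attention_kv_state: Optional[KVState] = None,
    ):
        del all_layer_outputs  # Unused in this sketch.
        # With this commit, _forward_for_mode reads "self_attention_kv_state" from
        # layer_kwargs and forwards it here; the default implementation ignores it.
        # This sketch reinstates it so every layer in the stack sees the external KV state.
        if external_self_attention_kv_state is not None:
            layer_kwargs["self_attention_kv_state"] = external_self_attention_kv_state

Note that _forward_for_mode now always passes the new keyword argument, which is why the test layer's existing override below adds it to its signature and immediately dels it.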
axlearn/common/attention_test.py

Lines changed: 48 additions & 0 deletions
@@ -4183,7 +4183,10 @@ def _update_layer_kwargs(
         layer_kwargs: dict[str, Any],
         *,
         all_layer_outputs: list[BaseTransformerLayer.Output],
+        external_self_attention_kv_state: Optional[KVState] = None,
     ):
+        del external_self_attention_kv_state
+
         layer_index = len(all_layer_outputs)
         if layer_index == 1:
             layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state
@@ -4586,6 +4589,51 @@ def test_skip_connection(self):
             0.0,
         )

+    def test_passthrough_update_layer_kwargs(self):
+        num_heads = 2
+        input_dim = 4
+        hidden_dim = 8
+        num_layers = 3
+
+        cfg = StackedTransformerLayer.default_config().set(name="test")
+        cfg.input_dim = input_dim
+        cfg.num_layers = num_layers
+
+        transformer_cfg = TransformerLayer.default_config()
+        transformer_cfg.self_attention.attention.num_heads = num_heads
+        transformer_cfg.feed_forward.hidden_dim = hidden_dim
+        cfg.layer = transformer_cfg
+
+        layer: StackedTransformerLayer = cfg.instantiate(parent=None)
+        state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))
+
+        input_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
+        expected_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
+        k_proj = jnp.zeros([3, 3])
+        v_proj = jnp.ones([3, 3])
+        input_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
+        expected_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
+        F(
+            layer,
+            prng_key=jax.random.PRNGKey(0),
+            state=state,
+            inputs=dict(
+                layer_kwargs={},
+                all_layer_outputs=[],
+                external_self_attention_kv_state=input_self_attention_kv_state,
+            ),
+            method="_update_layer_kwargs",
+            is_training=True,
+        )
+        self.assertNestedAllClose(
+            input_all_layer_outputs,
+            expected_all_layer_outputs,
+        )
+        self.assertNestedAllClose(
+            input_self_attention_kv_state,
+            expected_self_attention_kv_state,
+        )
+
     def test_update_layer_kwargs(self):
         batch_size = 2
         seq_len = 6

0 commit comments

Comments
 (0)