
Commit fc761b0

Author: Mark Lee

Unify init and prefill for attention layers. (#860)

* Unify init and prefill for attention layers.
* Fix some types and docstrings.

1 parent bad0f0f · commit fc761b0
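
Editor's note: this commit merges the previous `init_states(target_batch_size=..., target_max_len=...)` and `prefill_states(time_step=..., ...)` entry points into a single `init_states(time_step=..., ...)` that always returns a `(cache, output)` tuple. The snippet below is a minimal, self-contained sketch of that calling convention as exercised by the tests in this diff; `toy_init_states` and the simplified `TensorSpec` are hypothetical stand-ins, not axlearn code.

from typing import NamedTuple, Union

import jax.numpy as jnp


class TensorSpec(NamedTuple):
    # Simplified stand-in for axlearn.common.utils.TensorSpec: carries a shape only.
    shape: tuple


def toy_init_states(*, time_step, query: Union[TensorSpec, jnp.ndarray]):
    """Hypothetical helper mirroring the unified contract.

    With time_step=None, `query` may be just a TensorSpec; an empty cache is allocated
    and (cache, None) is returned. With a time_step array, `query` is real data; the
    cache is prefilled up to `time_step` and the computed outputs are returned as well.
    """
    batch, seq_len = query.shape[0], query.shape[1]
    cache = dict(
        key=jnp.zeros([batch, seq_len]),
        value=jnp.zeros([batch, seq_len]),
        time_step=jnp.zeros([batch], dtype=jnp.int32),
    )
    if time_step is None:
        return cache, None
    # Prefill path: pretend the projection is identity and fill positions < time_step.
    mask = jnp.arange(seq_len)[None, :] < time_step[:, None]
    cache["key"] = jnp.where(mask, query, 0.0)
    cache["value"] = jnp.where(mask, query, 0.0)
    cache["time_step"] = time_step
    return cache, dict(query=query, key=cache["key"], value=cache["value"])


# Init from scratch: only a shape spec is needed and the output is None.
cache, output = toy_init_states(time_step=None, query=TensorSpec((2, 6)))
assert output is None

# Prefill: real data goes in, and outputs come back alongside the filled cache.
cache, output = toy_init_states(time_step=jnp.array([3, 3]), query=jnp.ones((2, 6)))

In the real layers, the prefill branch runs the layer's forward computation for the prefilled positions, which is why the tests below assert that the second return value is None only when calling with `time_step=None`.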

File tree: 11 files changed (+644, -531 lines)

axlearn/common/attention.py

Lines changed: 294 additions & 328 deletions
Large diffs are not rendered by default.
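
The attention.py diff itself is not rendered here, but the call sites in the files below imply a unified method shaped roughly as follows. This is a hedged reconstruction inferred from those call sites, not the actual contents of attention.py; the class name `BaseQKVLinearSketch` and the exact annotations are assumptions.

from typing import Any, Optional, Tuple


class BaseQKVLinearSketch:
    """Assumed shape of the unified interface (placeholder class, not the axlearn one)."""

    def init_states(
        self, *, time_step: Optional[Any], query: Any, **kwargs
    ) -> Tuple[dict, Optional[Any]]:
        """Replaces the former init_states(target_batch_size, target_max_len) and
        prefill_states(time_step, query) pair.

        With time_step=None, `query` may be a TensorSpec and only an initialized cache
        is returned (the second element is None). With a time_step Tensor, the cache is
        prefilled through `time_step` and the corresponding outputs are returned too.
        """
        raise NotImplementedError(type(self))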

axlearn/common/attention_test.py

Lines changed: 123 additions & 18 deletions
@@ -116,6 +116,7 @@
     Nested,
     PartitionSpec,
     Tensor,
+    TensorSpec,
     VDict,
     as_tensor,
     flatten_items,

@@ -1472,7 +1473,10 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
     inputs=dict(query=query),
 )

-cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
+cache_state, init_output = layer.init_states(
+    time_step=None, query=TensorSpec([batch_size, tgt_len])
+)
+self.assertIsNone(init_output)
 step_querys = []
 step_keys = step_values = None
 for t in range(0, tgt_len, extend_step_len):

@@ -1531,18 +1535,19 @@ def __init__(self, cfg: Config, *, parent: Module):
 qkv_linear = parent.qkv_linear
 state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0))

-# Check dtypes from init_states
-cache, _ = F(
+# Check dtypes from init_states.
+(cache, init_output), _ = F(
     qkv_linear,
     prng_key=jax.random.PRNGKey(0),
     state=state,
     inputs=dict(
-        target_batch_size=target_batch_size,
-        target_max_len=target_max_len,
+        time_step=None,
+        query=TensorSpec([target_batch_size, target_max_len]),
     ),
     method="init_states",
     is_training=False,
 )
+self.assertIsNone(init_output)
 self.assertEqual(cache["key"].dtype, dtype)
 self.assertEqual(cache["value"].dtype, dtype)

@@ -1562,7 +1567,7 @@ def __init__(self, cfg: Config, *, parent: Module):
     prng_key=jax.random.PRNGKey(0),
     state=state,
     inputs=dict(time_step=time_step, query=query),
-    method="prefill_states",
+    method="init_states",
     is_training=False,
 )
 self.assertEqual(init_state["key"].dtype, dtype)

@@ -2448,9 +2453,14 @@ def _test_extend_step(
     inputs=inputs,
 )

-initial_state = layer.init_states(
-    target_batch_size=batch_size, target_max_len=tgt_len, kv_state=kv_state
+initial_state, initial_output = layer.init_states(
+    time_step=None,
+    query=TensorSpec([batch_size, tgt_len]),
+    kv_state=kv_state,
+    # This is unused for initializing state from scratch.
+    attention_logit_biases=None,
 )
+self.assertIsNone(initial_output)
 if kv_state is None:
     for k in ["key", "value"]:
         # Check that the cache dtype is inferred as the layer dtype.

@@ -2619,7 +2629,7 @@ def _test_prefill_states(
         attention_logit_biases=attention_logit_biases,
         return_aux=return_aux,
     ),
-    method="prefill_states",
+    method="init_states",
 )

 # Check time_step and shapes of state.

@@ -3227,6 +3237,96 @@ def test_multihead_attention_xl(self):
 )


+class TransformerAttentionLayerTest(TestCase):
+    @parameterized.parameters([False, True])
+    def test_forward_vs_extend_step(self, with_source: bool):
+        init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3)
+
+        model_dim = 8
+        layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim)
+        cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set(
+            **layer_kwargs
+        )
+        cfg.attention.set(num_heads=2, mask=causal_mask)
+        layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=init_prng)
+
+        batch, decode_len = 2, 6
+        target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim])
+        input_kwargs = {}
+
+        if with_source:
+            input_kwargs.update(
+                source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim])
+            )
+
+        forward_outputs, _ = F(
+            layer,
+            inputs=dict(target=jnp.asarray(target), **input_kwargs),
+            state=layer_params,
+            is_training=True,
+            prng_key=jax.random.PRNGKey(0),
+        )
+
+        for start_time_step in (-1, 0, 2, decode_len):
+            if start_time_step < 0:
+                (cached_states, init_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        time_step=None,
+                        target=TensorSpec(target.shape, target.dtype),
+                        **input_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="init_states",
+                )
+                self.assertIsNone(init_outputs)
+                data = jnp.zeros([batch, decode_len, model_dim])
+                start_time_step = 0
+            else:
+                (cached_states, prefill_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32),
+                        target=target,
+                        **input_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="init_states",
+                )
+                data = prefill_outputs.data
+
+            data = jnp.einsum("btd->tbd", data)
+
+            for time_step in range(start_time_step, decode_len):
+                extend_kwargs = {}
+                for k, v in input_kwargs.items():
+                    extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :])
+
+                (cached_states, extend_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        target=jnp.asarray(target[:, time_step : time_step + 1, :]),
+                        cached_states=cached_states,
+                        **extend_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="extend_step",
+                )
+                data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1))
+
+            data = jnp.einsum("tbd->btd", data)
+
+            # Prefill + extend_step == forward.
+            assert_allclose(forward_outputs.data, data)
+
+
 class TransformerFeedForwardLayerTest(TestCase):
     @parameterized.parameters(
         dict(rms_norm_summary=[]),

@@ -3392,20 +3492,21 @@ def _test_forward_vs_extend_step(
 for start_time_step in (-1, 0, 2, tgt_len):
     if start_time_step > tgt_len:
         continue
-    print(f"start_time_step={start_time_step}")
+    print(f"start_time_step={start_time_step} layer={type(layer)}")
     if start_time_step < 0:
-        cached_states, _ = F(
+        (cached_states, init_outputs), _ = F(
             layer,
             inputs=dict(
-                target_batch_size=batch_size,
-                target_max_len=tgt_len,
+                time_step=None,
+                data=TensorSpec([batch_size, tgt_len]),
                 **input_kwargs,
             ),
             state=layer_params,
             is_training=True,
             prng_key=jax.random.PRNGKey(0),
             method="init_states",
         )
+        self.assertIsNone(init_outputs)
         decoder_output = jnp.zeros_like(target)
         start_time_step = 0
     else:

@@ -3419,7 +3520,7 @@ def _test_forward_vs_extend_step(
     state=layer_params,
     is_training=True,
     prng_key=jax.random.PRNGKey(0),
-    method="prefill_states",
+    method="init_states",
 )
 decoder_output = prefill_outputs.data
 # Transpose to [tgt_len, batch_size, model_dim].

@@ -3850,7 +3951,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
 batch_size, src_len, tgt_len = 10, 4, 6
 num_dec_layers, model_dim, num_heads = 3, 16, 4

-cfg = transformer_type.default_config().set(
+cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set(
     name="test",
     input_dim=model_dim,
     num_layers=num_dec_layers,

@@ -3872,7 +3973,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
 layer_cfg.feed_forward.hidden_dim = model_dim * 4

 # Instantiate transformer stack.
-layer = cfg.instantiate(parent=None)
+layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None)
 layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))

 target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim])

@@ -3897,7 +3998,11 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
     is_training=False,
     prng_key=jax.random.PRNGKey(0),
 )
-initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
+initial_state, initial_output = layer.init_states(
+    time_step=None,
+    data=TensorSpec([batch_size, tgt_len]),
+)
+self.assertIsNone(initial_output)
 inputs = dict(
     cached_states=initial_state, cross_attention_data=source, return_aux=return_aux
 )

@@ -4036,7 +4141,7 @@ def test_transformer_prefill_states(self, transformer_type, layer_type):
         cross_attention_logit_biases=cross_attention_logit_biases,
         return_aux=return_aux,
     ),
-    method="prefill_states",
+    method="init_states",
 )

 # Zero-out outputs starting from initial time_step, and test that we can recover the full

axlearn/common/decoder.py

Lines changed: 7 additions & 5 deletions
@@ -47,7 +47,7 @@
     current_context,
     new_output_collection,
 )
-from axlearn.common.utils import Nested, NestedTensor, with_sharding_constraint
+from axlearn.common.utils import Nested, NestedTensor, TensorSpec, with_sharding_constraint


 # TODO(markblee): Remove this when we have a better solution at the decoding loop level.

@@ -492,7 +492,7 @@ def _forward_for_mode(
 assert cached_states is not None
 if input_segment_ids is not None:
     raise ValueError("input_segment_ids is not supported in INIT_STATES.")
-transformer_state, x = self.transformer.prefill_states(
+transformer_state, x = self.transformer.init_states(
     time_step=cached_states["transformer_state"],
     data=x,
     self_attention_logit_biases=self_attention_logit_biases,

@@ -584,10 +584,12 @@ def forward(
 def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTensor:
     """See `BaseDecoder.init_states` for details."""
     cfg: Decoder.Config = self.config
+    init_state, _ = self.transformer.init_states(
+        time_step=None,
+        data=TensorSpec([batch_size, max_sequence_length, cfg.dim]),
+    )
     return dict(
-        transformer_state=self.transformer.init_states(
-            target_batch_size=batch_size, target_max_len=max_sequence_length
-        ),
+        transformer_state=init_state,
         input_ids=jnp.full(
             (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
         ),

axlearn/common/encoder.py

Lines changed: 10 additions & 6 deletions
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Encoder layers."""
+
 import math
 from typing import Optional


@@ -20,7 +21,7 @@
 from axlearn.common.embedding import TransformerTextEmbeddings
 from axlearn.common.layers import BaseClassificationHead, set_dropout_rate_recursively
 from axlearn.common.module import Module, Tensor, child_context
-from axlearn.common.utils import NestedTensor
+from axlearn.common.utils import NestedTensor, TensorSpec


 class Encoder(BaseLayer):

@@ -167,12 +168,15 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTen
 Returns:
     The cache as a `NestedTensor` with key and value initialized.
 """
+cfg: CausalEncoder.Config = self.config
+init_state, _ = self.transformer.init_states(
+    time_step=None,
+    data=TensorSpec([batch_size, max_sequence_length, cfg.dim]),
+)
 return dict(
-    transformer_state=self.transformer.init_states(
-        target_batch_size=batch_size, target_max_len=max_sequence_length
-    ),
+    transformer_state=init_state,
     input_ids=jnp.full(
-        (batch_size, max_sequence_length), self.config.pad_token_id, dtype=jnp.int32
+        (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
     ),
     time_step=jnp.zeros(batch_size, dtype=jnp.int32),
 )

@@ -279,7 +283,7 @@ def prefill_states(
 # Note: this follows `Decoder.prefill_states` closely. Refer to that method for details.
 # TODO(markblee): Possibly consolidate some of this with decoder.
 x = self.emb(input_ids, token_type_ids=token_type_ids, positions=None)
-transformer_state, x = self.transformer.prefill_states(
+transformer_state, x = self.transformer.init_states(
     time_step=time_step,
     data=x,
     self_attention_logit_biases=self.compute_attention_logit_biases(input_ids),

axlearn/common/flash_attention/layer_test.py

Lines changed: 13 additions & 4 deletions
@@ -38,6 +38,7 @@
 from axlearn.common.module import Module
 from axlearn.common.module import functional as F
 from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
+from axlearn.common.utils import TensorSpec


 def _fake_inputs(

@@ -650,12 +651,20 @@ def test_extend_step(
 )

 # Prepare initial states.
-initial_state = test_layer.init_states(
-    target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state
+initial_state, initial_output = test_layer.init_states(
+    time_step=None,
+    query=TensorSpec([batch, seq_len]),
+    kv_state=kv_state,
+    attention_logit_biases=None,
 )
-ref_initial_state = ref_layer.init_states(
-    target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state
+ref_initial_state, ref_inital_output = ref_layer.init_states(
+    time_step=None,
+    query=TensorSpec([batch, seq_len]),
+    kv_state=kv_state,
+    attention_logit_biases=None,
 )
+self.assertIsNone(initial_output)
+self.assertIsNone(ref_inital_output)
 for k in ["key", "value"]:
     self.assertEqual(ref_initial_state["i_proj"][k].dtype, dtype)
     self.assertEqual(initial_state["i_proj"][k].dtype, dtype)

axlearn/common/lora_test.py

Lines changed: 6 additions & 4 deletions
@@ -26,7 +26,7 @@
 from axlearn.common.module import functional as F
 from axlearn.common.param_converter import as_torch_tensor
 from axlearn.common.test_utils import TestCase, assert_allclose
-from axlearn.common.utils import Tensor
+from axlearn.common.utils import Tensor, TensorSpec


 class LoraLinearTest(TestCase):

@@ -233,9 +233,11 @@ def test_extend_step(self, layer):
 q_proj, k_proj, v_proj = outputs
 forward_outputs = jnp.stack([q_proj, k_proj, v_proj])

-initial_cache_state = layer.init_states(
-    target_batch_size=batch_size, target_max_len=seq_len
+initial_cache_state, init_output = layer.init_states(
+    time_step=None,
+    query=TensorSpec([batch_size, seq_len]),
 )
+self.assertIsNone(init_output)

 decoder_inputs = dict(cached_states=initial_cache_state)
 decoder_outputs = jnp.zeros(shape=[seq_len, 3, batch_size, num_heads, per_head_dim])

@@ -305,7 +307,7 @@ def test_prefill_states(self):
     is_training=False,
     prng_key=jax.random.PRNGKey(456),
     inputs=dict(time_step=time_step, query=inputs),
-    method="prefill_states",
+    method="init_states",
 )
 time_step_mask = jnp.arange(seq_len) < time_step[:, None]
 # [batch, tgt_len, num_heads, per_head_dim].
