[Core] More-efficient cross-attention parallel QKV computation #7448


Closed

Commits (29)
c724b2a  wip cross qkv parallel linear (afeldman-nm, Aug 10, 2024)
94b78a9  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 10, 2024)
62e9f45  an approach that only works for unquantized linear (afeldman-nm, Aug 10, 2024)
9f9350d  working QCrossKVParallelLinear for unquantizedlinear case (afeldman-nm, Aug 10, 2024)
2bb4c3a  removed unnecessary comment (afeldman-nm, Aug 11, 2024)
158a148  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 11, 2024)
02b62cd  remember shard shapes during weight loading (afeldman-nm, Aug 11, 2024)
5952bc3  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 12, 2024)
659018c  reorganized (afeldman-nm, Aug 12, 2024)
94e5e37  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 12, 2024)
50c9696  Merge branch 'infra_enc_dec_cross' into infra_enc_dec_cross_infer (afeldman-nm, Aug 12, 2024)
427468f  modified test infra (afeldman-nm, Aug 12, 2024)
3ed28b1  test (afeldman-nm, Aug 12, 2024)
284eb05  typo (afeldman-nm, Aug 12, 2024)
0b2ee5c  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 12, 2024)
a78f42b  SOW (afeldman-nm, Aug 12, 2024)
6888228  removing tests (afeldman-nm, Aug 12, 2024)
5057726  formatting (afeldman-nm, Aug 12, 2024)
ed2c73a  Caching; refactoring; formatting (afeldman-nm, Aug 12, 2024)
761d34d  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 12, 2024)
efc5f75  small fix (afeldman-nm, Aug 12, 2024)
d536085  comments (afeldman-nm, Aug 12, 2024)
af6dc00  slight refactor (afeldman-nm, Aug 12, 2024)
de054e3  changes (afeldman-nm, Aug 12, 2024)
fab5773  test (afeldman-nm, Aug 12, 2024)
9a3b5ec  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 13, 2024)
aaf920b  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 16, 2024)
c299520  Merge branch 'main' into infra_enc_dec_cross (afeldman-nm, Aug 20, 2024)
585a9c8  refactor (afeldman-nm, Aug 20, 2024)

vllm/model_executor/layers/linear.py (220 additions, 0 deletions)

@@ -911,6 +911,226 @@ def weight_loader(self,
param_data.copy_(loaded_weight)


class _WeightWrapper:
'''
Wrapper for weight matrices. Helper class for
:class:`QCrossKVParallelLinear`.

Generally speaking, the quantized & unquantized linear method
implementations look like the following:

```
some_linear_function(x, layer.weight, bias)
```

i.e., they expect the `layer` argument to be an object with a `weight`
attribute, mirroring :class:`Linear`.

:class:`QCrossKVParallelLinear` wraps W_Q and [W_K W_V]
(two views of `self.weight`) in :class:`_WeightWrapper` instances, which
can both be passed to the linear method as the `layer` argument.
'''

def __init__(
self,
weight: torch.Tensor,
) -> None:
self.weight = torch.nn.Parameter(weight)


class QCrossKVParallelLinear(QKVParallelLinear):
"""Linear layer for the linear transformation of the query, key, and
value vectors in the cross-attention layer.

Q is computed from the previous decoder layer's outputs, while KV is
computed from the encoder output hidden states during prefill; thus,
`forward()` takes two tensor arguments.

The weight matrix is concatenated along the output dimension. However,
Q and KV are computed in two separate steps, which operate respectively
on W_Q and [W_K W_V] (two views obtained by slicing `self.weight`).

The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number
of query heads (e.g., multi-query/grouped-query attention), the key/value
head may be replicated while the query heads are partitioned.
"""

def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
'''
The :class:`QKVParallelLinear` parent class packs [W_Q W_K W_V]
and the corresponding bias vectors into `self.weight` and `self.bias`
(respectively) during weight loading.

However, cross-attention QKV computation requires that we
(1) partially unpack `self.weight` into the W_Q weight matrix and
the packed [W_K W_V] weight matrix, and
(2) partially unpack `self.bias` into the Q bias vector and
the packed KV bias vector.

To avoid recomputing these views of the underlying `self.weight`
and `self.bias`, we cache them.
'''
self._param_views_not_cached: bool = True
self._cached_q_weights_wrapper: Optional[_WeightWrapper] = None
self._cached_kv_weights_wrapper: Optional[_WeightWrapper] = None
self._cached_q_bias: Optional[torch.Tensor] = None
self._cached_kv_bias: Optional[torch.Tensor] = None

super().__init__(hidden_size=hidden_size,
head_size=head_size,
total_num_heads=total_num_heads,
total_num_kv_heads=total_num_kv_heads,
bias=bias,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)

def _maybe_cache_param_views(self) -> None:
'''
Compute the W_Q weights, packed [W_K W_V] weights, Q bias vector, and
packed KV bias vector once, in order to cache them.
'''

if self._param_views_not_cached:
q_shard_begin_offset = self._get_shard_offset_mapping('q')
kv_shards_begin_offset = self._get_shard_offset_mapping('k')
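# Note: `self.weight` packs [W_Q; W_K; W_V] along the output (row)
# dimension, so slicing at the K shard offset splits it into the W_Q view
# and the packed [W_K W_V] view.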
self._cached_q_weights_wrapper = _WeightWrapper(
self.weight[q_shard_begin_offset:kv_shards_begin_offset, :])
self._cached_kv_weights_wrapper = _WeightWrapper(
self.weight[kv_shards_begin_offset:, :])
self._cached_q_bias = self.bias[
q_shard_begin_offset:kv_shards_begin_offset]
self._cached_kv_bias = self.bias[kv_shards_begin_offset:]
self._param_views_not_cached = False

def _maybe_gather_output(
self,
q_output_parallel: torch.Tensor,
kv_output_parallel: Optional[torch.Tensor],
is_decode_phase: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
'''
Perform all-gather, if required.

Arguments:

* `q_output_parallel`: computed Q on current GPU
* `kv_output_parallel`: computed KV on current GPU
* `is_decode_phase`: skip KV all-gather if True

Returns:

* Q all-gather result if required, otherwise `q_output_parallel`
* For KV:
* If all-gather required,
* KV all-gather result if in prefill-phase
* None if in decode-phase
* `kv_output_parallel` otherwise
'''

if self.gather_output:
# All-gather across the partitions.
return (
tensor_model_parallel_all_gather(q_output_parallel),
(None if is_decode_phase else
tensor_model_parallel_all_gather(kv_output_parallel)),
)

return (
q_output_parallel,
kv_output_parallel,  # None in the decode phase
)

def _apply_w_conditional_bias(
self,
weights: _WeightWrapper,
input: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
'''
Invoke the linear method, passing the bias argument only if required.

Arguments:

* `weights`: parameter matrix
* `input`: hidden states
* `bias`: bias

Returns:

* `weights` * `input` if `self.skip_bias_add`
* `weights` * `input` + `bias` otherwise
'''
assert self.quant_method is not None
return self.quant_method.apply(weights, input,
None if self.skip_bias_add else bias)

def forward(
self,
decoder_input_: torch.Tensor,
encoder_input_: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
'''
Arguments:

* `decoder_input_`: Q will be computed using these hidden states
* `encoder_input_`: KV is computed using these hidden states
* If `None`, KV computations are skipped (implicitly decode-phase)

Returns:

* Q = (`decoder_input_`) x W_Q
* KV = (`encoder_input_`) x [W_K, W_V]
* (`None` if `encoder_input_` is `None`)
* Q bias vector (`None` unless `self.skip_bias_add`)
* KV bias vector (`None` unless `self.skip_bias_add`)
'''
self._maybe_cache_param_views()
assert self._cached_q_weights_wrapper is not None
assert self._cached_kv_weights_wrapper is not None
assert self._cached_q_bias is not None
assert self._cached_kv_bias is not None

# Compute Q and maybe KV
is_decode_phase = encoder_input_ is None
q_output_parallel = self._apply_w_conditional_bias(
self._cached_q_weights_wrapper, decoder_input_,
self._cached_q_bias)
kv_output_parallel = None if is_decode_phase else (
self._apply_w_conditional_bias(self._cached_kv_weights_wrapper,
encoder_input_,
self._cached_kv_bias))

# All-gather if needed
(
q_output,
kv_output,
) = self._maybe_gather_output(q_output_parallel, kv_output_parallel,
is_decode_phase)

return (
q_output,
kv_output,
self._cached_q_bias if self.skip_bias_add else None,
self._cached_kv_bias if self.skip_bias_add else None,
)


class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.

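As a rough usage sketch (not part of the diff above), the prefill versus decode calling convention of the new layer would look roughly like the following. The sizes are illustrative assumptions, and constructing the layer assumes vLLM's tensor-parallel/distributed state has already been initialized (e.g. a single-GPU setup):

```python
import torch

from vllm.model_executor.layers.linear import QCrossKVParallelLinear

# Illustrative sizes only.
hidden_size, head_size, total_num_heads = 1024, 64, 16
qkv_proj = QCrossKVParallelLinear(hidden_size, head_size, total_num_heads)

decoder_hidden = torch.randn(8, hidden_size)   # decoder token hidden states
encoder_hidden = torch.randn(32, hidden_size)  # encoder output hidden states

# Prefill: Q is computed from the decoder stream and the packed [K V] from
# the encoder stream, in a single forward() call.
q, kv, _, _ = qkv_proj(decoder_hidden, encoder_hidden)
kv_size = qkv_proj.num_kv_heads * head_size
k, v = kv.split([kv_size, kv_size], dim=-1)

# Decode: the encoder input is omitted, so the KV computation (and any KV
# all-gather) is skipped and the returned KV is None.
q_only, kv_none, _, _ = qkv_proj(decoder_hidden)
assert kv_none is None
```

Compared with calling `QKVParallelLinear` once per input stream and discarding the unused Q or KV slices (the previous approach in bart.py below), each projection here only touches the weights it actually needs.
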
vllm/model_executor/models/bart.py (11 additions, 10 deletions)

@@ -29,6 +29,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QCrossKVParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -303,7 +304,7 @@ def __init__(
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5

- self.qkv_proj = QKVParallelLinear(
+ self.qkv_proj = QCrossKVParallelLinear(

Collaborator:

Was wondering if, in the encoder-decoder e2e test (e.g. test_bart.py), you observed any increase in the input tokens/s or the output tokens/s with this change?

Contributor (Author):

Good question. I ran a very informal test to prove to myself that QCrossKVParallelLinear improved the total runtime of the encoder/decoder example script. I should run an actual benchmark in order to compare tokens/s.

self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
@@ -351,18 +352,18 @@ def forward(
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""

- # (afeldman-nm 2024/07/22) TODO:
- # Need a more efficient solution for q/k/v
- qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
- q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- if encoder_hidden_states is None:
+ (
+ q,
+ kv,
+ _,
+ _,
+ ) = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)
+
+ if kv is None:
k = None
v = None
else:
- qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
- _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ k, v = kv.split([self.kv_size, self.kv_size], dim=-1)

attn_output = self.attn(q,
k,
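Following up on the reviewer's tokens/s question, a quick (non-rigorous) comparison would be to time the same encoder/decoder workload on main and on this branch and divide generated tokens by wall-clock time. The model name, prompt set, and sampling parameters below are placeholder assumptions rather than anything specified in this PR; a proper benchmark script would still be preferable:

```python
# Informal throughput check: run once on main and once on this branch,
# then compare the printed output tokens/s. Model, prompts, and sampling
# parameters are placeholder assumptions.
import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", dtype="float")
prompts = ["vLLM is a high-throughput, memory-efficient inference engine."] * 64
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params)
elapsed = time.perf_counter() - start

output_tokens = sum(len(out.outputs[0].token_ids) for out in outputs)
print(f"output tokens/s: {output_tokens / elapsed:.1f}")
```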