Add alibi support (#69)
Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
wenbinc-Bin authored Jul 1, 2024
1 parent 20eafe9 commit aae39b1
Showing 6 changed files with 67 additions and 64 deletions.
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -111,6 +111,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
max_seq_len : Optional[int] = 4096,
) -> None:
raise NotImplementedError

74 changes: 39 additions & 35 deletions vllm/attention/backends/habana_attn.py
@@ -136,16 +136,21 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
max_seq_len : Optional[int] = 4096,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.position_bias = None
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.bfloat16)
self.position_bias = _make_alibi_bias(alibi_slopes,
num_kv_heads,
alibi_slopes.dtype,
max_seq_len)
self.alibi_slopes = alibi_slopes

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -199,13 +204,17 @@ def forward(
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# TODO: move this outside of model
assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
attn_bias = prefill_meta.attn_bias
if self.alibi_slopes is not None:
attn_bias.add_(self.position_bias[:, :, -attn_bias.size(2):, -attn_bias.size(3):])

query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
out = xops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=prefill_meta.attn_bias,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
)
@@ -236,10 +245,9 @@ def forward(
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
self.position_bias,
kv_scale
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
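
Aside, not part of the diff: the prefill path above adds position_bias[:, :, -attn_bias.size(2):, -attn_bias.size(3):], i.e. the trailing slice of the precomputed bias rather than the leading one. That works because the ALiBi bias depends only on the relative distance j - i, so the trailing seq_len block of a max_seq_len bias equals the bias a fresh prompt of that length would get. A minimal check with illustrative sizes:

import torch

# Illustrative sizes only.
max_seq_len, seq_len = 8, 3

# Same relative-position construction as _make_alibi_bias: bias[i, j] = j - i.
pos = torch.arange(max_seq_len, dtype=torch.float32)
full_bias = pos[None, :] - pos[:, None]   # (max_seq_len, max_seq_len)

# The trailing seq_len x seq_len block equals the bias of a fresh prompt of
# length seq_len, because only the distance j - i matters.
tail = full_bias[-seq_len:, -seq_len:]
fresh = pos[:seq_len][None, :] - pos[:seq_len][:, None]
assert torch.equal(tail, fresh)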

@@ -248,33 +256,29 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

return attn_biases
seq_len: int,
) -> torch.Tensor:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
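
For reference only (not code from this commit), a standalone sketch of the bias the rewritten _make_alibi_bias returns for made-up slopes, omitting the 8-element padding and the grouped-KV unflatten step:

import torch

alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625], dtype=torch.bfloat16)
seq_len = 5

pos = torch.arange(seq_len, dtype=alibi_slopes.dtype)
bias = pos[None, :] - pos[:, None]                    # (seq_len, seq_len), bias[i, j] = j - i
bias = bias[None, None, :, :].expand(1, alibi_slopes.shape[0], -1, -1).clone()
bias.mul_(alibi_slopes[:, None, None])                # scale each head by its slope

print(bias.shape)        # torch.Size([1, 4, 5, 5])
print(bias[0, 0, 0, :])  # 0.0, 0.5, 1.0, 1.5, 2.0 for the head with slope 0.5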
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -29,12 +29,13 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
max_seq_len: Optional[int] = 4096,
) -> None:
super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)
alibi_slopes, sliding_window, max_seq_len)

def forward(
self,
7 changes: 5 additions & 2 deletions vllm/hpu/ops.py
@@ -36,7 +36,7 @@ def fetch_from_cache(cache, blocks, permutations):


@hpu_utils.with_mark_steps
def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None:
def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes=None, kv_cache_dtype=None) -> None:
seq_len = block_tables.size(1)
batch_size, query_heads, _ = query.shape
_, _, kv_heads, _ = key_cache.shape
@@ -55,7 +55,10 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block
mask = mask.unsqueeze(2)

attn_weights = [torch.matmul(query, k) for k in keys]
attn_weights = (torch.cat(attn_weights, dim=-1)
attn_weights = torch.cat(attn_weights, dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:,:,-attn_weights.size(2):, -attn_weights.size(3):])
attn_weights = (attn_weights
.masked_fill(mask, min_inf)
.softmax(dim=-1))

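
Outside the diff, a minimal self-contained sketch of the decode-path pattern the hunk above introduces: take the trailing slice of a precomputed per-head bias, add it to the raw scores, then mask and softmax. Shapes and tensors here are illustrative placeholders, not the Habana kernels.

import torch

batch, heads, q_len, kv_len, max_len = 2, 4, 1, 16, 32

scores = torch.randn(batch, heads, q_len, kv_len)        # raw attention scores
alibi_bias = torch.randn(1, heads, max_len, max_len)     # precomputed once at init
mask = torch.zeros(batch, heads, q_len, kv_len, dtype=torch.bool)

scores.add_(alibi_bias[:, :, -scores.size(2):, -scores.size(3):])
probs = scores.masked_fill(mask, float('-inf')).softmax(dim=-1)
print(probs.shape)   # torch.Size([2, 4, 1, 16])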
3 changes: 2 additions & 1 deletion vllm/model_executor/models/mpt.py
@@ -107,7 +107,8 @@ def __init__(
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
max_seq_len=config.max_seq_len)

def forward(
self,
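
The MPT layer above now also forwards config.max_seq_len so the attention backend can precompute the bias once. As background only (this helper is hypothetical, not vLLM's), the standard ALiBi slope schedule for a power-of-two head count is a geometric sequence:

import math
import torch

def alibi_slopes_for(num_heads: int) -> torch.Tensor:
    # ALiBi-paper schedule: head i (1-based) gets slope 2 ** (-8 * i / num_heads).
    assert math.log2(num_heads).is_integer(), "sketch covers power-of-two head counts only"
    return torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])

print(alibi_slopes_for(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])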
43 changes: 18 additions & 25 deletions vllm/worker/habana_model_runner.py
@@ -115,31 +115,24 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype):
prefill_metadata = attn_metadata.prefill_metadata
if prefill_metadata is None:
return attn_metadata
#FIXME: Restore alibi support
#if self.alibi_slopes is None:
if True:
seq_lens_t = prefill_metadata.seq_lens_tensor
len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32)
.view(1, seq_len)
.ge(seq_lens_t.unsqueeze(-1))
.view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(
torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool),
diagonal=1
)
mask = causal_mask.logical_or(len_mask)
attn_bias = (torch.zeros_like(mask, dtype=dtype)
.masked_fill_(mask, -math.inf))
#FIXME: Restore sliding window support
#if self.sliding_window is not None:
prefill_metadata = prefill_metadata._replace(attn_bias=attn_bias)
attn_metadata = attn_metadata._replace(prefill_metadata=prefill_metadata)
return attn_metadata
else:
# FIXME: This needs updating...
prefill_meta.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

seq_lens_t = prefill_metadata.seq_lens_tensor
len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32)
.view(1, seq_len)
.ge(seq_lens_t.unsqueeze(-1))
.view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(
torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool),
diagonal=1
)
mask = causal_mask.logical_or(len_mask)
attn_bias = (torch.zeros_like(mask, dtype=dtype)
.masked_fill_(mask, -math.inf))
#FIXME: Restore sliding window support
#if self.sliding_window is not None:
prefill_metadata = prefill_metadata._replace(attn_bias=attn_bias)
attn_metadata = attn_metadata._replace(prefill_metadata=prefill_metadata)
return attn_metadata


def forward(self, *args, **kwargs):
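
To make the prefill mask construction above concrete, a tiny worked example with made-up sizes (a batch of two prompts, the second padded): positions beyond each prompt's true length and the causal upper triangle both become -inf in the additive bias.

import math
import torch

batch_size, seq_len = 2, 4
seq_lens_t = torch.tensor([4, 2], dtype=torch.int32)   # second prompt has length 2

len_mask = (torch.arange(0, seq_len, dtype=torch.int32)
            .view(1, seq_len)
            .ge(seq_lens_t.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool), diagonal=1)
mask = causal_mask.logical_or(len_mask)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)

print(attn_bias[1, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., -inf, -inf]])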
