Commit

Merge remote-tracking branch 'origin' into ray/lora-gptbigcode-implementation
raywanb committed May 22, 2024
2 parents 53100c3 + 99eff67 commit 97e2549
Showing 7 changed files with 74 additions and 10 deletions.
25 changes: 25 additions & 0 deletions README.md
@@ -87,6 +87,31 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

## Sponsors

vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!

<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->

- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Trainy
- UC Berkeley
- UC San Diego

We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.

## Citation

If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
24 changes: 24 additions & 0 deletions docs/source/community/sponsors.md
@@ -0,0 +1,24 @@
# Sponsors

vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!

<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with README.md. -->

- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Trainy
- UC Berkeley
- UC San Diego

We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -118,6 +118,7 @@ Documentation
:caption: Community

community/meetups
community/sponsors

Indices and tables
==================
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ exclude = [
 
 [tool.codespell]
 ignore-words-list = "dout, te, indicies"
-skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"
+skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
 
 [tool.isort]
 use_parentheses = true
12 changes: 8 additions & 4 deletions vllm/attention/backends/flash_attn.py
@@ -9,11 +9,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
@@ -237,10 +239,12 @@ def __init__(
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
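For reference, here is a hedged sketch of how a caller might query the new staticmethod. Only the import path and `FlashAttentionBackend.get_supported_head_sizes()` come from the diff above; the try/except guard and the example head size of 80 are illustrative, and the happy path needs a working vLLM install with flash-attn.

```python
# Hypothetical usage sketch, not part of the commit. Only the import path and
# get_supported_head_sizes() are taken from the diff; the guard and the head
# size of 80 are illustrative assumptions.
try:
    from vllm.attention.backends.flash_attn import FlashAttentionBackend

    head_size = 80  # e.g. a model whose attention heads are 80-dimensional
    supported = FlashAttentionBackend.get_supported_head_sizes()
    if head_size in supported:
        print(f"FlashAttention can handle head size {head_size}")
    else:
        print(f"Head size {head_size} is not in the supported sizes {supported}")
except ImportError:
    print("vLLM (with flash-attn) is not installed; skipping the check")
```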
16 changes: 13 additions & 3 deletions vllm/attention/selector.py
@@ -34,11 +34,21 @@ def get_attn_backend(
                                  sliding_window, dtype, kv_cache_dtype,
                                  block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
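The selector change boils down to: import the FlashAttention backend first, ask it for its supported head sizes, and quietly downgrade to XFormers when the model's head size is not in the list. Below is a minimal, self-contained sketch of that decision, assuming a local stub in place of the real vLLM classes (the actual backends require GPU dependencies); only the method name and the size list mirror the diff.

```python
# Self-contained sketch of the fallback logic added in this commit. The stub
# class stands in for vllm.attention.backends.flash_attn.FlashAttentionBackend;
# only get_supported_head_sizes() and its list mirror the diff, everything
# else is illustrative.
from typing import List


class FlashAttentionBackendStub:
    """Local stand-in for the real FlashAttentionBackend."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Same list as in vllm/attention/backends/flash_attn.py above.
        return [32, 64, 96, 128, 160, 192, 224, 256]


def choose_backend(head_size: int) -> str:
    """Prefer FlashAttention; fall back to XFormers for unsupported head sizes."""
    if head_size in FlashAttentionBackendStub.get_supported_head_sizes():
        return "flash-attn"
    return "xformers"


if __name__ == "__main__":
    print(choose_backend(128))  # -> flash-attn
    print(choose_backend(80))   # -> xformers (80 is not in the supported list)
```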
4 changes: 2 additions & 2 deletions vllm/engine/arg_utils.py
@@ -341,9 +341,9 @@ def add_cli_args(
                             help='Maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode. '
-                            '(DEPRECATED. Use --max-seq_len-to-capture instead'
+                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                             ')')
-        parser.add_argument('--max-seq_len-to-capture',
+        parser.add_argument('--max-seq-len-to-capture',
                             type=int,
                             default=EngineArgs.max_seq_len_to_capture,
                             help='Maximum sequence length covered by CUDA '
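The last file only renames the CLI flag from --max-seq_len-to-capture to the consistently hyphenated --max-seq-len-to-capture. A small illustrative sketch with a bare argparse parser (not vLLM's EngineArgs; the default of 8192 is an assumption) shows that the Python attribute name is unchanged either way, since argparse converts hyphens in option names to underscores.

```python
# Illustrative only: a bare argparse parser, not vLLM's EngineArgs.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-seq-len-to-capture',
                    type=int,
                    default=8192,  # assumed default, not taken from the diff
                    help='Maximum sequence length covered by CUDA graphs.')

args = parser.parse_args(['--max-seq-len-to-capture', '4096'])
# argparse turns the hyphens into underscores, so the attribute name is the
# same one the old spelling produced: max_seq_len_to_capture.
print(args.max_seq_len_to_capture)  # -> 4096
```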
