From ad25d6520d16d6f6be273266c842080d56005e57 Mon Sep 17 00:00:00 2001 From: maktukmak <31421551+maktukmak@users.noreply.github.com> Date: Mon, 11 Dec 2023 00:28:27 -0800 Subject: [PATCH] CPU only build (#9) --- Makefile | 6 + cpu.Dockerfile | 77 ++ csrc/dispatch_utils.h | 4 + csrc/pybind.cpp | 6 + Dockerfile => gpu.Dockerfile | 14 +- pyproject.toml | 2 +- requirements-build-cpu.txt | 6 + ...ts-build.txt => requirements-build-gpu.txt | 0 requirements-cpu.txt | 15 + requirements.txt => requirements-gpu.txt | 0 setup.py | 95 +- vllm/model_executor/layers/activation.py | 2 +- vllm/model_executor/layers/attention.py | 23 +- .../model_executor/layers/rotary_embedding.py | 17 +- .../layers/xformers_cpu/__init__.py | 0 .../layers/xformers_cpu/attn_bias.py | 929 ++++++++++++++++++ vllm/utils.py | 2 +- 17 files changed, 1143 insertions(+), 55 deletions(-) create mode 100644 cpu.Dockerfile rename Dockerfile => gpu.Dockerfile (89%) create mode 100644 requirements-build-cpu.txt rename requirements-build.txt => requirements-build-gpu.txt (100%) create mode 100644 requirements-cpu.txt rename requirements.txt => requirements-gpu.txt (100%) create mode 100644 vllm/model_executor/layers/xformers_cpu/__init__.py create mode 100644 vllm/model_executor/layers/xformers_cpu/attn_bias.py diff --git a/Makefile b/Makefile index 2ecf3305c2193..35ae0731c050d 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,12 @@ sanitizer: py_install: VLLM_BUILD_CPU_OPS=1 MAX_JOBS=JOBS pip install --no-build-isolation -v -e . +py_install_cpu: + VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=JOBS pip install --no-build-isolation -v -e . 
+ +install_vllm: + MAX_JOBS=JOBS pip install -v git+https://github.com/intel-sandbox/vllm-xpu.git@dev -f https://download.pytorch.org/whl/torch_stable.html + package: VLLM_BUILD_CPU_OPS=1 MAX_JOBS=JOBS python setup.py bdist_wheel echo "Wheel package is saved in ./dist/" diff --git a/cpu.Dockerfile b/cpu.Dockerfile new file mode 100644 index 0000000000000..15f9d5be6408a --- /dev/null +++ b/cpu.Dockerfile @@ -0,0 +1,77 @@ +FROM python:3.10 AS dev + +RUN apt-get update -y \ + && apt-get install -y python3-pip + +WORKDIR /workspace + +# install build and runtime dependencies +COPY requirements-cpu.txt requirements-cpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-cpu.txt + +# install development dependencies +COPY requirements-dev.txt requirements-dev.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-dev.txt + +# image to build pytorch extensions +FROM dev AS build + +# install build dependencies +COPY requirements-build-cpu.txt requirements-build-cpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-build-cpu.txt + +# copy input files +COPY csrc csrc +COPY setup.py setup.py +COPY requirements-cpu.txt requirements-cpu.txt +COPY pyproject.toml pyproject.toml +COPY vllm/__init__.py vllm/__init__.py + +# max jobs used by Ninja to build extensions +ENV MAX_JOBS=$max_jobs +RUN python3 setup.py build_ext --inplace + +# image to run unit testing suite +FROM dev AS test + +# copy pytorch extensions separately to avoid having to rebuild +# when python code changes +COPY --from=build /workspace/vllm/*.so /workspace/vllm/ +COPY tests tests +COPY vllm vllm + +ENTRYPOINT ["python3", "-m", "pytest", "tests"] + +# use CUDA base as CUDA runtime dependencies are already installed via pip +FROM python:3.10 AS dev + +# libnccl required for ray +RUN apt-get update -y \ + && apt-get install -y python3-pip + +WORKDIR /workspace +COPY requirements-cpu.txt requirements-cpu.txt +RUN 
--mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-cpu.txt + +FROM vllm-base AS vllm +COPY --from=build /workspace/vllm/*.so /workspace/vllm/ +COPY vllm vllm + +EXPOSE 8000 +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] + +# openai api server alternative +FROM vllm-base AS vllm-openai +# install additional dependencies for openai api server +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install accelerate fschat + +COPY --from=build /workspace/vllm/*.so /workspace/vllm/ +COPY vllm vllm + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] + diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 956b2b88fc65e..c15cda7c39950 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,10 +14,14 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#ifdef VLLM_BUILD_CPU_ONLY +#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) +#else #define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) \ case c10::DeviceType::CUDA: { \ return BASENAME(__VA_ARGS__); \ } +#endif #ifdef VLLM_BUILD_CPU_OPS #define VLLM_DISPATCH_TO_CPU_CASE(BASENAME, ...) 
\ diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 688793168a91b..4d5279b2e5a1d 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -87,6 +87,12 @@ void gptq_shuffle_dispatch( VLLM_DISPATCH_DEVICES(q_weight.device(), gptq_shuffle, q_weight, q_perm); } +#ifdef VLLM_BUILD_CPU_ONLY +int get_device_attribute( + int attribute, + int device_id) { return 94387; } +#endif + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // vLLM custom ops pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); diff --git a/Dockerfile b/gpu.Dockerfile similarity index 89% rename from Dockerfile rename to gpu.Dockerfile index 44b1dd17d7e02..a480aac16f144 100644 --- a/Dockerfile +++ b/gpu.Dockerfile @@ -10,9 +10,9 @@ RUN apt-get update -y \ WORKDIR /workspace # install build and runtime dependencies -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-gpu.txt # install development dependencies COPY requirements-dev.txt requirements-dev.txt @@ -25,14 +25,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \ FROM dev AS build # install build dependencies -COPY requirements-build.txt requirements-build.txt +COPY requirements-build-gpu.txt requirements-build-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-build.txt + pip install -r requirements-build-gpu.txt # copy input files COPY csrc csrc COPY setup.py setup.py -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt COPY pyproject.toml pyproject.toml COPY vllm/__init__.py vllm/__init__.py @@ -75,9 +75,9 @@ RUN apt-get update -y \ && apt-get install -y python3-pip WORKDIR /workspace -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-gpu.txt #################### 
RUNTIME BASE IMAGE #################### diff --git a/pyproject.toml b/pyproject.toml index b197256f6ff55..54bec2e47013b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.1.2", + "torch == 2.1.2+cpu", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build-cpu.txt b/requirements-build-cpu.txt new file mode 100644 index 0000000000000..04312d11f61b9 --- /dev/null +++ b/requirements-build-cpu.txt @@ -0,0 +1,6 @@ +# Should be mirrored in pyproject.toml +ninja +packaging +setuptools>=49.4.0 +torch==2.1.2+cpu +wheel \ No newline at end of file diff --git a/requirements-build.txt b/requirements-build-gpu.txt similarity index 100% rename from requirements-build.txt rename to requirements-build-gpu.txt diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000..6f3ff1cce61fe --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,15 @@ +ninja # For faster builds. +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +pybind11 +sentencepiece # Required for LLaMA tokenizer. +numpy +einops # Required for phi-1_5 +torch == 2.1.2+cpu +transformers >= 4.34.0 # Required for Mistral. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. 
+aioprometheus[starlette] diff --git a/requirements.txt b/requirements-gpu.txt similarity index 100% rename from requirements.txt rename to requirements-gpu.txt diff --git a/setup.py b/setup.py index ac4b47961d4af..aff6e86b235ec 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,13 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME + +BUILD_CPU_ONLY = os.getenv('VLLM_BUILD_CPU_ONLY', "1") == "1" + +if not BUILD_CPU_ONLY: + from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME +else: + from torch.utils.cpp_extension import BuildExtension, CppExtension ROOT_DIR = os.path.dirname(__file__) @@ -21,11 +27,11 @@ def _is_hip() -> bool: - return torch.version.hip is not None + return torch.version.hip is not None and not BUILD_CPU_ONLY def _is_cuda() -> bool: - return torch.version.cuda is not None + return torch.version.cuda is not None and not BUILD_CPU_ONLY # Compiler flags. @@ -86,7 +92,6 @@ def get_hipcc_rocm_version(): print("Could not find HIP version in the output") return None - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -137,6 +142,19 @@ def get_torch_arch_list() -> Set[str]: stacklevel=2) return arch_list +if not BUILD_CPU_ONLY: + # First, check the TORCH_CUDA_ARCH_LIST environment variable. + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") # First, check the TORCH_CUDA_ARCH_LIST environment variable. 
compute_capabilities = get_torch_arch_list() @@ -211,9 +229,11 @@ def get_torch_arch_list() -> Set[str]: f"amdgpu_arch_found: {amd_arch}") # Setup CPU Operations -BUILD_CPU_OPS = os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1" +BUILD_CPU_OPS = (os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1" or BUILD_CPU_ONLY) CPU_OPS_SOURCES = [] if BUILD_CPU_OPS: + if BUILD_CPU_ONLY: + CXX_FLAGS += ["-DVLLM_BUILD_CPU_ONLY"] CXX_FLAGS += [ "-DVLLM_BUILD_CPU_OPS", "-fopenmp", "-mavx512f", "-mavx512bf16", "-mavx512vl" @@ -228,29 +248,42 @@ def get_torch_arch_list() -> Set[str]: ext_modules = [] -vllm_extension_sources = [ - "csrc/cache_kernels.cu", - "csrc/attention/attention_kernels.cu", - "csrc/pos_encoding_kernels.cu", - "csrc/activation_kernels.cu", - "csrc/layernorm_kernels.cu", - "csrc/quantization/squeezellm/quant_cuda_kernel.cu", - "csrc/quantization/gptq/q_gemm.cu", - "csrc/cuda_utils_kernels.cu", - "csrc/pybind.cpp", -] + CPU_OPS_SOURCES +if not BUILD_CPU_ONLY: + vllm_extension_sources = [ + "csrc/cache_kernels.cu", + "csrc/attention/attention_kernels.cu", + "csrc/pos_encoding_kernels.cu", + "csrc/activation_kernels.cu", + "csrc/layernorm_kernels.cu", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/quantization/gptq/q_gemm.cu", + "csrc/cuda_utils_kernels.cu", + "csrc/pybind.cpp", + ] + CPU_OPS_SOURCES + + if _is_cuda(): + vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + + vllm_extension = CUDAExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) +else: + vllm_extension_sources = [ + "csrc/pybind.cpp", + ] + CPU_OPS_SOURCES + vllm_extension = CppExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + }, + ) -if _is_cuda(): - vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") - -vllm_extension = CUDAExtension( - name="vllm._C", - sources=vllm_extension_sources, - extra_compile_args={ - "cxx": 
CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) ext_modules.append(vllm_extension) @@ -280,7 +313,7 @@ def get_vllm_version() -> str: if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] version += f"+rocm{rocm_version_str}" - else: + elif _is_cuda(): cuda_version = str(nvcc_cuda_version) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] @@ -303,9 +336,13 @@ def get_requirements() -> List[str]: if _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_cuda(): + with open(get_path("requirements-gpu.txt")) as f: + requirements = f.read().strip().split("\n") else: - with open(get_path("requirements.txt")) as f: + with open(get_path("requirements-cpu.txt")) as f: requirements = f.read().strip().split("\n") + return requirements diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 1af120d13cd4b..3d4ffe1e5c263 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -27,7 +27,7 @@ class SiluAndMul(nn.Module): def _forward(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] + return (F.silu(x[..., :d].float()) * x[..., d:].float()).to(x) def forward(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 28ecbd968a251..54821d0a2a2a7 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -3,9 +3,17 @@ import torch import torch.nn as nn -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) +try: + from xformers import ops as xops +except: + pass + +try: + from xformers.ops.fmha.attn_bias import 
(BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) +except: + from vllm.model_executor.layers.xformers_cpu.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) from vllm._C import ops from vllm._C import cache_ops @@ -160,12 +168,9 @@ def forward( op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (is_hip()) else None, ) if not self.cpu_only else torch.nn.functional.scaled_dot_product_attention( - query.movedim(1, - query.dim() - - 2), key.movedim(1, - query.dim() - 2), - value.movedim(1, - value.dim() - 2), input_metadata.attn_bias, + query.movedim(1, query.dim() -2), key.movedim(1, query.dim() - 2), + value.movedim(1, value.dim() - 2), + input_metadata.attn_bias, 0.0).movedim(query.dim() - 2, 1).contiguous() output = out.view_as(query) else: diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index abeaa6abff8ff..3ee7e054accd3 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -98,16 +98,19 @@ def _forward( key: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" + device = query.device + dtype = query.dtype + query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] + query_rot = query[..., :self.rotary_dim].float() + key_rot = key[..., :self.rotary_dim].float() if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim:].float() + key_pass = key[..., self.rotary_dim:].float() - cos_sin = self.cos_sin_cache[positions] + cos_sin = self.cos_sin_cache[positions].float() cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the @@ -128,8 
+131,8 @@ def _forward( else: query = query_rot key = key_rot - query = query.flatten(-2) - key = key.flatten(-2) + query = query.flatten(-2).to(dtype=dtype, device=device) + key = key.flatten(-2).to(dtype=dtype, device=device) return query, key def forward( diff --git a/vllm/model_executor/layers/xformers_cpu/__init__.py b/vllm/model_executor/layers/xformers_cpu/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/xformers_cpu/attn_bias.py b/vllm/model_executor/layers/xformers_cpu/attn_bias.py new file mode 100644 index 0000000000000..89b9d53469c87 --- /dev/null +++ b/vllm/model_executor/layers/xformers_cpu/attn_bias.py @@ -0,0 +1,929 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. + + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. 
+ + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. + + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +def _materialize_causal_mask( + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + *, + window_size: Optional[int] = None, + from_bottomright: bool = False, +) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = 0 + if from_bottomright: + shift = num_keys - num_queries + + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class LocalAttentionFromBottomRightMask(AttentionBias): + """ + A local attention mask + + The query at position :math:`q` can attend the key at position :math:`k` if + :math:`q - window\\_left <= k + s <= q + window\\_right` + + With :math:`s = num\\_queries - num\\_keys` + + :Example: + + .. 
code-block:: python + + import torch + from xformers.ops import fmha + + bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2) + print(bias.materialize(shape=(4, 4)).exp()) + print(bias.materialize(shape=(4, 5)).exp()) + + .. code-block:: text + + # 4x4 + tensor([[1., 1., 1., 0.], + [1., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 0., 1., 1.]]) + + # 4x5 + tensor([[1., 1., 1., 1., 0.], + [0., 1., 1., 1., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1.]]) + + :Illustration: + + .. figure:: /_static/local_attn.png + :width: 240px + + The total window size is :math:`window\\_left + 1 + window\\_right` + """ + + window_left: int + window_right: int + + def __post_init__(self) -> None: + if self.window_left < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_left > 0` but got window_left={self.window_left}" + ) + if self.window_right < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_right > 0` but got window_right={self.window_right}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + mask = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = num_keys - num_queries + + mask = torch.triu(mask, diagonal=shift - self.window_left) + mask = torch.tril(mask, diagonal=shift + self.window_right) + mask = torch.log(mask) + return mask.to(dtype) + + +class LowerTriangularMask(AttentionBias): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. 
+ + See also :attr:`LowerTriangularFromBottomRightMask` if the number + of queries is not equal to the number of keys/values. + """ + + def __init__(self, *tensor_args, **tensor_kwargs) -> None: + # NOTE: Unused arguments, we keep them for backward compatibility + super().__init__() + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask(shape, dtype=dtype, device=device) + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + """ + Creates a new causal mask with an arbitrary ``torch.Tensor`` bias + """ + return LowerTriangularMaskWithTensorBias(bias) + + +class LowerTriangularFromBottomRightMask(AttentionBias): + """ + A causal masking. + + This mask is exactly the same as :attr:`LowerTriangularMask` when there is + the same number of queries and keys. + When the number of queries is different from the number of keys, + it is a triangular mask shifted so that the last query can attend to + the last key. + In other words, a query Q cannot attend to a key which is nearer the + final key than Q is to the final query. + + + .. figure:: /_static/causal_bottom_right.png + + The difference between :attr:`LowerTriangularMask` (left) and + :attr:`LowerTriangularFromBottomRightMask` (right). They become + equivalent if the number of queries equals the number of keys. + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, dtype=dtype, device=device, from_bottomright=True + ) + + def make_local_attention( + self, window_size: int + ) -> "LowerTriangularFromBottomRightLocalAttentionMask": + """ + Create a new bias which combines local + causal attention. 
+ + See :attr:`LowerTriangularFromBottomRightLocalAttentionMask` + """ + return LowerTriangularFromBottomRightLocalAttentionMask(window_size) + + +@dataclass +class LowerTriangularFromBottomRightLocalAttentionMask( + LowerTriangularFromBottomRightMask +): + """ + A mask that combines both :attr:`LowerTriangularFromBottomRightMask` and + local attention. + + A query whose distance from the final query is X cannot attend to a key + whose distance to the final key is either of: + + * less than X (i.e. "causal attention", same as :attr:`LowerTriangularFromBottomRightMask`) + * greater than X + window_size (i.e. "local attention") + + + .. figure:: /_static/causal_bottom_right_local.png + + The mask from :attr:`LowerTriangularFromBottomRightLocalAttentionMask`. + The green area is calculated, and the grey area is masked out. + """ + + _window_size: int + + def __post_init__(self) -> None: + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + def __init__(self, bias: torch.Tensor) -> None: + self._bias = bias + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._bias + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. 
+ + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) 
Represents the division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. + + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> None: + self.seqlen = self.seqlen.to(device, non_blocking=True) + super().to(device) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, seqlens: Sequence[int], padding: int + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all(seqlen <= padding for seqlen in seqlens) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + return cls( + seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + 
@dataclass
class BlockDiagonalMask(AttentionBias):
    """
    A block-diagonal mask that can be passed as ``attn_bias``
    argument to :attr:`xformers.ops.memory_efficient_attention`.

    Queries and Keys are each divided into the same number of blocks.
    Queries in block i only attend to keys in block i.

    .. figure:: /_static/block_diag_bias.png

    This bias can be used to handle a batch of sequences of
    different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`

    :Example:

    .. code-block:: python

        import torch
        from xformers.ops import fmha

        K = 16
        dtype = torch.float16
        device = "cuda"
        list_x = [
            torch.randn([1, 3, 1, K], dtype=dtype, device=device),
            torch.randn([1, 6, 1, K], dtype=dtype, device=device),
            torch.randn([1, 2, 1, K], dtype=dtype, device=device),
        ]
        attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
        linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)

        q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
        out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        list_out = attn_bias.split(out)
        print(list_out[0].shape)  # [1, 3, 1, K]
        assert tuple(list_out[0].shape) == (1, 3, 1, K)

    """

    # Block boundaries on the query side.
    q_seqinfo: _SeqLenInfo
    # Block boundaries on the key side; `from_seqlens` reuses the very same
    # object as `q_seqinfo` when query and key lengths coincide.
    k_seqinfo: _SeqLenInfo
    # Batch size of every input tensor, recorded by the `from_tensor_list*`
    # constructors and forwarded to `_SeqLenInfo.split`.
    _batch_sizes: Optional[Sequence[int]] = None

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the additive bias for one diagonal block.

        The base class puts no restriction inside a block, so the bias is all
        zeros. Subclasses override this hook to shape the per-block pattern.

        Args:
            shape: ``(num_queries, num_keys)`` of the block.
            dtype: dtype of the returned tensor.
            device: device of the returned tensor.
        """
        return torch.zeros(
            shape,
            dtype=dtype,
            device=device,
        )

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Materialize the attention bias - for debugging & testing

        Args:
            shape: Target shape. The last two entries must equal the total
                number of queries and keys (the final ``seqstart_py`` values).
            dtype: dtype of the returned tensor.
            device: device of the returned tensor.

        Returns:
            A tensor of ``shape`` that is ``-inf`` outside the diagonal blocks
            and holds ``_create_block_mask`` output inside each block. Leading
            dimensions are produced with ``expand``.
        """
        assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
            shape[-1],
            self.k_seqinfo.seqstart_py[-1],
        )
        assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
            shape[-2],
            self.q_seqinfo.seqstart_py[-1],
        )
        # Start fully masked, then open up one diagonal block at a time.
        mask = torch.full(shape[-2:], -math.inf, dtype=dtype, device=device)
        for (q_start, q_end), (k_start, k_end) in zip(
            self.q_seqinfo.intervals(),
            self.k_seqinfo.intervals(),
        ):
            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
                (q_end - q_start, k_end - k_start),
                dtype=dtype,
                device=device,
            )
        # Prepend singleton dims so the 2-D mask can be expanded to `shape`.
        for _ in range(len(shape) - 2):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(
        cls,
        q_seqlen: Sequence[int],
        kv_seqlen: Optional[Sequence[int]] = None,
    ) -> "BlockDiagonalMask":
        """Creates a :attr:`BlockDiagonalMask` from a list of tensor lengths for query and key/value.

        Args:
            q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
            kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
                (Defaults to ``q_seqlen``.)
        Returns:
            BlockDiagonalMask
        """
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
        q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
        if kv_seqlen is None or q_seqlen == kv_seqlen:
            # Share one object: `split` relies on `q_seqinfo is k_seqinfo`.
            k_seqinfo = q_seqinfo
        else:
            k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)

    @classmethod
    def from_tensor_list(
        cls,
        tensors: Sequence[torch.Tensor],
    ) -> Tuple["BlockDiagonalMask", torch.Tensor]:
        """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
        concatenated on the sequence length dimension

        .. figure:: /_static/block_diag_cat_split.png

        See also :attr:`BlockDiagonalMask.split` to split the returned
        :attr:`torch.Tensor` back to a list of tensors of varying sequence length

        Args:
            tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
                All tensors should have the same dimension and the same batch size ``B``, but
                they can have different sequence length ``M``.

        Returns:
            Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
            along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
        """
        batch_sizes = [tensor.shape[0] for tensor in tensors]
        # One block per batch element: a [B, M, *] tensor yields B blocks of
        # length M (same construction as `from_tensor_lists_qkv`).
        seqlens = []
        for x in tensors:
            seqlens += [x.shape[1]] * x.shape[0]
        block_diag = cls.from_seqlens(seqlens)
        block_diag._batch_sizes = batch_sizes
        # Fold the batch dimension into the sequence dimension.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
        concat_tensors = torch.cat(tensors_bs1, dim=1)
        return block_diag, concat_tensors

    @classmethod
    def from_tensor_lists_qkv(
        cls,
        tensors_q: Sequence[torch.Tensor],
        tensors_k: Sequence[torch.Tensor],
        tensors_v: Optional[Sequence[torch.Tensor]] = None,
    ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Build a mask from parallel lists of query / key (/ value) tensors.

        Args:
            tensors_q: Query tensors; dim 0 is the batch, dim 1 the sequence.
            tensors_k: Key tensors, one per query tensor, with matching batch size.
            tensors_v: Optional value tensors whose first two dims must match
                the corresponding key tensor.

        Returns:
            The mask, followed by the queries, keys and (if given) values, each
            reshaped to batch size 1 and concatenated along dim 1. The last
            entry is ``None`` when ``tensors_v`` is ``None``.
        """
        assert len(tensors_q) == len(tensors_k)
        assert tensors_v is None or len(tensors_v) == len(tensors_q)
        batch_sizes = [tensor.shape[0] for tensor in tensors_q]
        q_seqlens, kv_seqlens = [], []
        for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
            assert q.shape[0] == k.shape[0]
            q_seqlens += [q.shape[1]] * q.shape[0]
            kv_seqlens += [k.shape[1]] * k.shape[0]
            assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
        block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
        block_diag._batch_sizes = batch_sizes
        return (
            block_diag,
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
            if tensors_v is not None
            else None,
        )

    def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        """Split ``tensor`` using the query-side block boundaries."""
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        """Split ``tensor`` using the key-side block boundaries."""
        return self.k_seqinfo.split(tensor, self._batch_sizes)

    def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        """The inverse operation of :attr:`BlockDiagonalMask.from_tensor_list`

        Only valid when queries and keys share the same ``_SeqLenInfo`` object
        (asserted below).

        Args:
            tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``

        Returns:
            Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
        """
        assert self.q_seqinfo is self.k_seqinfo
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def make_causal(self) -> "BlockDiagonalCausalMask":
        """Makes each block causal"""
        return BlockDiagonalCausalMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
        )

    def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
        """Makes each block causal with a possible non-causal prefix"""
        return BlockDiagonalCausalFromBottomRightMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
        )

    def make_local_attention(
        self, window_size: int
    ) -> "BlockDiagonalCausalLocalAttentionMask":
        """Experimental: Makes each block causal with local attention"""
        return BlockDiagonalCausalLocalAttentionMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
            _window_size=window_size,
        )

    def make_local_attention_from_bottomright(
        self, window_size: int
    ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
        """Experimental: Makes each block causal with local attention, start from bottom right"""
        return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
            _window_size=window_size,
        )


@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.

    Queries and Keys are each divided into the same number of blocks.
    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is farther from the initial key in block i than Q
    is from the initial query in block i.
    """

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the per-block bias produced by ``LowerTriangularMask``."""
        return LowerTriangularMask().materialize(
            shape,
            dtype=dtype,
            device=device,
        )


@dataclass
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
    This mask allows for a non-causal prefix
    NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
    defined (softmax of vector of `-inf` in the attention)

    Queries and keys are each divided into the same number of blocks.
    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is nearer the final key in block i than Q is to the
    final query in block i.
    """

    def __post_init__(self) -> None:
        """Reject any block that has fewer keys than queries.

        Raises:
            ValueError: if some block has ``num_keys < num_queries``.
        """
        for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
            zip(
                self.q_seqinfo.intervals(),
                self.k_seqinfo.intervals(),
            )
        ):
            num_queries = q_end - q_start
            num_keys = k_end - k_start
            if num_keys < num_queries:
                raise ValueError(
                    f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
                    " Expected `num_keys >= num_queries`"
                )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the per-block bias produced by ``LowerTriangularFromBottomRightMask``."""
        return LowerTriangularFromBottomRightMask().materialize(
            shape=shape, dtype=dtype, device=device
        )


@dataclass
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
    except an offset on causality is allowed for each block and we support padding for k/v

    The keys and values are divided into blocks which are padded out to
    the same total length.
    For example, if there is space for 12 keys, for three blocks of
    max length 4, but we only want to use the first 2, 3 and 2
    of each block, use `kv_padding=4` and `kv_seqlen=[2, 3, 2]`.
    The queries are divided into blocks, without padding, of lengths given by
    q_seqlen.

    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is not in use (i.e. in the padded area),
    nor one which is nearer to the final key in block i
    than Q is to the final query in block i.
    """

    q_seqinfo: _SeqLenInfo
    k_seqinfo: _PaddedSeqLenInfo
    causal_diagonal: Any = None  # unused. Exists for BC only.

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the per-block bias produced by ``LowerTriangularFromBottomRightMask``."""
        return LowerTriangularFromBottomRightMask().materialize(
            shape=shape, dtype=dtype, device=device
        )

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Materialize the attention bias - for debugging & testing

        Args:
            shape: Target shape. The last two entries must equal the final
                ``seqstart_py`` values of the query and key infos.
            dtype: dtype of the returned tensor.
            device: device of the returned tensor.

        Returns:
            A tensor of ``shape`` that is ``-inf`` everywhere except on the
            (query interval, key interval) blocks, which hold
            ``_create_block_mask`` output.

        Raises:
            ValueError: if the trailing dims of ``shape`` do not match.
        """
        if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
            raise ValueError("k shapes wrong")
        if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
            raise ValueError("q shapes wrong")
        # Start fully masked, then open up one block at a time.
        mask = torch.full(shape[-2:], -math.inf, dtype=dtype, device=device)
        for (q_start, q_end), (k_start, k_end) in zip(
            self.q_seqinfo.intervals(),
            self.k_seqinfo.intervals(),
        ):
            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
                (q_end - q_start, k_end - k_start),
                dtype=dtype,
                device=device,
            )
        # Prepend singleton dims so the 2-D mask can be expanded to `shape`.
        for _ in range(len(shape) - 2):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(
        cls,
        q_seqlen: Sequence[int],
        kv_padding: int,
        kv_seqlen: Sequence[int],
        causal_diagonal: Any = None,
    ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
        """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
        lengths for query and key/value.

        Args:
            q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
            kv_padding (int): Padding for k/v - also an upperbound on each individual key length
            kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
            causal_diagonal: unused, for BC only
        Returns:
            BlockDiagonalCausalWithOffsetPaddedKeysMask
        """
        # NOTE(review): `kv_seqlen` is a required argument, so the `is None`
        # escape looks vestigial; kept as-is to preserve existing behavior.
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
            q_seqlen,
            kv_seqlen,
        )
        q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
        k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)


@dataclass
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
    """
    (Experimental feature)
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
    This makes the mask "local" and the attention pattern banded.

    Query i only attends to keys in its block and cannot attend keys further than "window_size"
    from it.
    """

    _window_size: int = 0  # forced due to inheritance and default arguments

    def __post_init__(self):
        """Validate the window size against every block's lengths.

        Raises:
            ValueError: if ``_window_size`` is not strictly positive.
            RuntimeError: if a block has ``q_len - window_size >= k_len``.
        """
        if self._window_size <= 0:
            raise ValueError(
                f"Expected `window_size > 0`, but window_size={self._window_size}"
            )
        # Per-block lengths are the differences of consecutive start offsets.
        q_starts = self.q_seqinfo.seqstart_py
        k_starts = self.k_seqinfo.seqstart_py
        q_seqlen = [end - start for start, end in zip(q_starts, q_starts[1:])]
        kv_seqlen = [end - start for start, end in zip(k_starts, k_starts[1:])]
        for q, k in zip(q_seqlen, kv_seqlen):
            if q - self._window_size >= k:
                # Each query only attends to keys no further than window_size back.
                # When q >= k + window_size, there will be a query for which the window doesn't reach any key.
                raise RuntimeError(
                    f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
                )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the per-block bias from ``_materialize_causal_mask`` with this window size."""
        return _materialize_causal_mask(
            shape,
            dtype=dtype,
            device=device,
            window_size=self._window_size,
        )


@dataclass
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
    BlockDiagonalCausalFromBottomRightMask
):
    """
    (Experimental feature)
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask`.
    This makes the mask "local" and the attention pattern banded.

    Query i only attends to keys in its block and cannot attend keys further than "window_size"
    from it.
    """

    _window_size: int = 0  # forced due to inheritance and default arguments

    def __post_init__(self):
        """Run the parent's per-block length check, then validate the window size.

        Raises:
            ValueError: from the parent check, or if ``_window_size`` is not
                strictly positive.
        """
        super().__post_init__()
        if self._window_size <= 0:
            raise ValueError(
                f"Expected `window_size > 0`, but window_size={self._window_size}"
            )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the per-block bias from ``_materialize_causal_mask`` with ``from_bottomright=True``."""
        return _materialize_causal_mask(
            shape,
            dtype=dtype,
            device=device,
            window_size=self._window_size,
            from_bottomright=True,
        )