Add flash-attn #26239
Changes from 18 commits

New file (azure build timeout config; presumably the recipe's conda-forge.yml):

@@ -0,0 +1,2 @@
azure:
  timeout_minutes: 360

New file: meta.yaml

@@ -0,0 +1,66 @@
{% set name = "flash-attn" %}
{% set version = "2.5.8" %}

package:
  name: {{ name|lower }}
  version: {{ version }}

source:
  - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/flash_attn-{{ version }}.tar.gz
    sha256: 2e5b2bcff6d5cff40d494af91ecd1eb3c5b4520a6ce7a0a8b1f9c1ed129fb402
  # Overwrite with a simpler build script that doesn't try to revend pre-compiled binaries
  - path: pyproject.toml
  - path: setup.py

build:
  number: 0
  script: {{ PYTHON }} -m pip install . -vvv --no-deps --no-build-isolation
  script_env:
    - MAX_JOBS=$CPU_COUNT
    - TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0+PTX
  skip: true  # [cuda_compiler_version in (undefined, "None")]
  skip: true  # [not linux]
  rpaths:
    - lib/
    # PyTorch libs are in site-packages instead of with other shared objects
    - {{ SP_DIR }}/torch/lib/

requirements:
  build:
    - {{ compiler('c') }}
    - {{ compiler('cxx') }}
    - {{ compiler('cuda') }}
    - {{ stdlib('c') }}
    - ninja
  host:
    - cuda-version {{ cuda_compiler_version }}  # same cuda for host and build
    - cuda-cudart-dev  # [(cuda_compiler_version or "").startswith("12")]
    - libtorch  # required until pytorch run_exports libtorch
    - pip
    - python
    - pytorch
    - pytorch =*=cuda*
    - setuptools
  run:
    - einops
    - python
    - pytorch =*=cuda*

test:
  imports:
    - flash_attn

Review comment: This test may need to be commented out because the test runners don't have a GPU, so imports might fail.

Reply: Imports seemed to have worked at https://dev.azure.com/conda-forge/feedstock-builds/_build/results?buildId=928722&view=logs&j=4f860608-e5f8-5c9c-4eb0-308a99ecb61e&t=02ef1a5c-d960-5c54-fcea-983775f057bb&l=1352. Done:
    export PREFIX=/home/conda/staged-recipes/build_artifacts/flash-attn_1715047997366/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_place
    export SRC_DIR=/home/conda/staged-recipes/build_artifacts/flash-attn_1715047997366/test_tmp
    import: 'flash_attn'
    import: 'flash_attn'
    + pip check
    No broken requirements found.
    + exit 0
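
For context, a minimal sketch of what the `test.imports` entry amounts to (illustrative only, not part of this PR). As the log above shows, the import succeeds even on a runner without a GPU, since loading the compiled extension does not require a visible CUDA device:

    # Roughly what conda-build does for each name under `test.imports`.
    # Importing flash_attn loads the compiled flash_attn_2_cuda extension.
    import importlib

    for module in ["flash_attn"]:
        importlib.import_module(module)
        print(f"import: '{module}'")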

meta.yaml, continued:

  commands:
    - pip check
  requires:
    - pip

about:
  home: https://github.com/Dao-AILab/flash-attention
  summary: 'Flash Attention: Fast and Memory-Efficient Exact Attention'
  license: BSD-3-Clause
  license_file: LICENSE

extra:
  recipe-maintainers:
    - carterbox
    - weiji14
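
A note on the build section above: `TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0+PTX` is read by pytorch's cpp_extension machinery when the CUDA sources are compiled. As an illustrative sketch (not part of the recipe; a simplified stand-in for what torch.utils.cpp_extension actually does), this is roughly how such a string expands into nvcc -gencode flags:

    # Expand a TORCH_CUDA_ARCH_LIST-style string into the nvcc -gencode
    # flags it conventionally maps to. A "+PTX" suffix additionally embeds
    # forward-compatible PTX for that architecture.
    def gencode_flags(arch_list: str) -> list[str]:
        flags = []
        for entry in arch_list.split(";"):
            ptx = entry.endswith("+PTX")
            num = entry.removesuffix("+PTX").replace(".", "")
            flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
            if ptx:
                flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
        return flags

    # Five flags: sm_80, sm_86, sm_89, sm_90, plus PTX for compute_90.
    print(gencode_flags("8.0;8.6;8.9;9.0+PTX"))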

New file: pyproject.toml

@@ -0,0 +1,28 @@
[build-system]
requires = ["setuptools>=62", "torch", "ninja"]
build-backend = "setuptools.build_meta"

[project]
dynamic = ["version"]
name = "flash_attn"
authors = [
    {name = "Tri Dao", email = "trid@cs.stanford.edu"},
]
description = "Flash Attention: Fast and Memory-Efficient Exact Attention"
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: Unix",
    "License :: OSI Approved :: BSD License",
]
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
    "torch",
    "einops",
]

[project.urls]
Homepage = "https://github.com/Dao-AILab/flash-attention"

[tool.setuptools.dynamic]
version = {attr = "flash_attn.__version__"}
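
Because [project] declares `dynamic = ["version"]`, setuptools resolves the version at build time from the attribute named under [tool.setuptools.dynamic]. A quick illustrative check after installation (not part of the PR):

    # The dynamic version points at flash_attn.__version__, so the
    # installed package should report the version pinned in the recipe.
    import flash_attn

    print(flash_attn.__version__)  # expected: "2.5.8" for this recipe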

New file: setup.py

@@ -0,0 +1,109 @@
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Conda-forge Contributors.

"""Since this package is a pytorch extension, this setup file uses the custom
CUDAExtension build system from pytorch. This ensures compatible compiler
args, headers, etc. for pytorch.

Read more at the pytorch docs:
https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension
"""

import pathlib

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_this_dir = pathlib.Path(__file__).parent.absolute()

setup(
    packages=find_packages(
        include=["flash_attn*"],
    ),
    ext_modules=[
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
            ],
            extra_compile_args={
                "cxx": [
                    "-std=c++17",
                ],
                "nvcc": [
                    "-std=c++17",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    # "--ptxas-options=-v",
                    # "--ptxas-options=-O2",
                    # "-lineinfo",
                    # "-DFLASHATTENTION_DISABLE_BACKWARD",
                    # "-DFLASHATTENTION_DISABLE_DROPOUT",
                    # "-DFLASHATTENTION_DISABLE_ALIBI",
                    # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                    # "-DFLASHATTENTION_DISABLE_LOCAL",
                ],
            },
            include_dirs=[
                _this_dir / "csrc" / "flash_attn",
                _this_dir / "csrc" / "flash_attn" / "src",
                _this_dir / "csrc" / "cutlass" / "include",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
)
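
Once the package is built, the extension is exercised through the flash_attn Python API. A hypothetical usage check, not part of the PR, requiring an Ampere-or-newer GPU to match the _sm80 kernels compiled above:

    # Run the fused attention kernel once. flash_attn_func expects
    # (batch, seqlen, nheads, headdim) tensors in fp16 or bf16 on CUDA.
    import torch
    from flash_attn import flash_attn_func

    q, k, v = (
        torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    out = flash_attn_func(q, k, v)
    print(out.shape)  # torch.Size([1, 128, 8, 64])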

Review comment: Hmm, it seems that the azure builds still timeout after 30-40min (e.g. at https://dev.azure.com/conda-forge/feedstock-builds/_build/results?buildId=927858&view=logs&jobId=67448ffb-e003-5bfa-c062-cee3af60fcba&j=67448ffb-e003-5bfa-c062-cee3af60fcba&t=818ff20d-11b7-59db-6ce1-bb4df921454a). Maybe this only works on the feedstock repo?

Reply: 🤦🏼 You're right. But also, it looks like the timeout is already set to 360 minutes in staged-recipes. So probably the builds are failing for other reasons. Perhaps the worker crashes by running out of RAM or disk space? Let's try reducing the compute load as much as possible.

Review comment: Keep this file because we need to have it in the feedstock.