
Add flash-attn #26239

Merged May 7, 2024 (23 commits; changes shown are from the first 18 commits)

Commits
07ec11e  Add flash-attn (weiji14, May 4, 2024)
a4def75  Add cuda compiler (weiji14, May 4, 2024)
06713e3  Add c and cxx compilers (weiji14, May 4, 2024)
51d7d74  Add stdlib c (weiji14, May 4, 2024)
aa17a2c  Skip build on non-cuda platforms (weiji14, May 4, 2024)
4d8b37c  Add libcublas-dev, libcusolver-dev, libcusparse-dev to host deps (weiji14, May 4, 2024)
04a346a  Remove noarch: python (weiji14, May 4, 2024)
16414ff  Remove minimum pin on python version (weiji14, May 4, 2024)
2d2212d  BLD: Fixup build environment and variables (carterbox, May 4, 2024)
2cde3c1  Sort dependencies alphabetically (weiji14, May 5, 2024)
501aa9d  Set azure timeout_minutes to 360 in conda-forge.yml (weiji14, May 6, 2024)
a1b1faa  Set TORCH_CUDA_ARCH_LIST to 8.0 and above (weiji14, May 6, 2024)
5235314  Drop quotes from `script_env` (weiji14, May 6, 2024)
ef03f90  DEV: Reduce archs and jobs (carterbox, May 6, 2024)
ebae578  Add missing host deps and rpaths (carterbox, May 6, 2024)
460eeb2  Merge remote-tracking branch 'forge/main' into add-flash-attn (carterbox, May 6, 2024)
317646a  BLD: Replace setup script with simpler one (carterbox, May 6, 2024)
0b81f6f  Update recipes/flash-attn/meta.yaml (carterbox, May 6, 2024)
96e817a  Add license for CUTLASS (weiji14, May 6, 2024)
0733767  Add libcublas-dev, libcusolver-dev, libcusparse-dev to host deps again (weiji14, May 6, 2024)
37e676a  ignore_run_exports_from libcublas-dev, libcusolver-dev, libcusparse-dev (weiji14, May 7, 2024)
fc2fc76  Temporarily set TORCH_CUDA_ARCH_LIST=8.6+PTX and MAX_JOBS=1 (weiji14, May 7, 2024)
63dcb65  BLD: Lower CUDA arch target to 8.0 (carterbox, May 7, 2024)
2 changes: 2 additions & 0 deletions recipes/flash-attn/conda-forge.yml
@@ -0,0 +1,2 @@
azure:
  timeout_minutes: 360
Comment on lines +1 to +2

Member:
🤦🏼 You're right. But also, it looks like the timeout is already set to 360 minutes in staged-recipes, so the builds are probably failing for other reasons. Perhaps the worker crashes by running out of RAM or disk space? Let's try reducing the compute load as much as possible.

Member:
Keep this file because we need to have it in the feedstock.
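
To make "reducing the compute load" concrete: the two knobs the later commits turn down are MAX_JOBS and TORCH_CUDA_ARCH_LIST, both of which the torch cpp_extension build reads from the environment, so the recipe only needs to export them via script_env. A minimal sketch from the build script's point of view (not part of the PR; the values are illustrative):

    import os

    # Hedged sketch: fewer parallel nvcc jobs -> lower peak RAM on the worker.
    os.environ["MAX_JOBS"] = "1"
    # Hedged sketch: one SM target instead of a fat binary -> less compile time and disk usage.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"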

66 changes: 66 additions & 0 deletions recipes/flash-attn/meta.yaml
@@ -0,0 +1,66 @@
{% set name = "flash-attn" %}
{% set version = "2.5.8" %}

package:
  name: {{ name|lower }}
  version: {{ version }}

source:
  - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/flash_attn-{{ version }}.tar.gz
    sha256: 2e5b2bcff6d5cff40d494af91ecd1eb3c5b4520a6ce7a0a8b1f9c1ed129fb402
  # Overwrite with a simpler build script that doesn't try to revend pre-compiled binaries
  - path: pyproject.toml
  - path: setup.py

build:
  number: 0
  script: {{ PYTHON }} -m pip install . -vvv --no-deps --no-build-isolation
  script_env:
    - MAX_JOBS=$CPU_COUNT
    - TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0+PTX
  skip: true  # [cuda_compiler_version in (undefined, "None")]
  skip: true  # [not linux]
  rpaths:
    - lib/
    # PyTorch libs are in site-packages instead of with other shared objects
    - {{ SP_DIR }}/torch/lib/

requirements:
  build:
    - {{ compiler('c') }}
    - {{ compiler('cxx') }}
    - {{ compiler('cuda') }}
    - {{ stdlib('c') }}
    - ninja
  host:
    - cuda-version {{ cuda_compiler_version }}  # same cuda for host and build
    - cuda-cudart-dev  # [(cuda_compiler_version or "").startswith("12")]
    - libtorch  # required until pytorch run_exports libtorch
    - pip
    - python
    - pytorch
    - pytorch =*=cuda*
    - setuptools
  run:
    - einops
    - python
    - pytorch =*=cuda*

test:
  imports:
    - flash_attn

Member:
This test may need to be commented out because the test runners don't have a GPU, so imports might fail.

Member Author:
Imports seemed to have worked at https://dev.azure.com/conda-forge/feedstock-builds/_build/results?buildId=928722&view=logs&j=4f860608-e5f8-5c9c-4eb0-308a99ecb61e&t=02ef1a5c-d960-5c54-fcea-983775f057bb&l=1352

done
export PREFIX=/home/conda/staged-recipes/build_artifacts/flash-attn_1715047997366/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_place
export SRC_DIR=/home/conda/staged-recipes/build_artifacts/flash-attn_1715047997366/test_tmp
import: 'flash_attn'
import: 'flash_attn'
+ pip check
No broken requirements found.
+ exit 0
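
As additional context, a minimal sketch (not part of the recipe) of why the import-only test can pass on a GPU-less worker: importing flash_attn just loads the compiled extension against libtorch and the CUDA runtime libraries, while an actual kernel launch is what needs a device.

    import torch
    import flash_attn  # loads the flash_attn_2_cuda extension; no GPU is needed for the import itself

    print("flash_attn", flash_attn.__version__)
    print("GPU visible:", torch.cuda.is_available())  # False on the Azure CI workers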

  commands:
    - pip check
  requires:
    - pip

about:
  home: https://github.com/Dao-AILab/flash-attention
  summary: 'Flash Attention: Fast and Memory-Efficient Exact Attention'
  license: BSD-3-Clause
  license_file: LICENSE

extra:
  recipe-maintainers:
    - carterbox
    - weiji14
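
One hedged way to sanity-check the rpaths entry in the recipe above (the {{ SP_DIR }}/torch/lib one) after a build, assuming binutils' readelf is available in the test environment, is to dump the dynamic section of the installed extension and look for a RUNPATH pointing into site-packages:

    import glob
    import subprocess
    import sysconfig

    # Hedged sketch, not part of the recipe: the extension is installed as a
    # top-level module, so it lives directly under site-packages.
    site_packages = sysconfig.get_paths()["purelib"]
    for ext in glob.glob(f"{site_packages}/flash_attn_2_cuda*.so"):
        dyn = subprocess.run(["readelf", "-d", ext], capture_output=True, text=True).stdout
        print([line for line in dyn.splitlines() if "RPATH" in line or "RUNPATH" in line])
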
28 changes: 28 additions & 0 deletions recipes/flash-attn/pyproject.toml
@@ -0,0 +1,28 @@
[build-system]
requires = ["setuptools>=62", "torch", "ninja"]
build-backend = "setuptools.build_meta"

[project]
dynamic = ["version"]
name = "flash_attn"
authors = [
    {name = "Tri Dao", email = "trid@cs.stanford.edu"},
]
description = "Flash Attention: Fast and Memory-Efficient Exact Attention"
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: Unix",
    "License :: OSI Approved :: BSD License",
]
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
    "torch",
    "einops",
]

[project.urls]
Homepage = "https://github.com/Dao-AILab/flash-attention"

[tool.setuptools.dynamic]
version = {attr = "flash_attn.__version__"}
109 changes: 109 additions & 0 deletions recipes/flash-attn/setup.py
@@ -0,0 +1,109 @@
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Conda-forge Contributors.

"""Since this package is a pytorch extension, this setup file uses the custom
CUDAExtension build system from pytorch. This ensures that compatible compiler
args, headers, etc for pytorch.

Read more at the pytorch docs:
https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension
"""

import pathlib

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_this_dir = pathlib.Path(__file__).parent.absolute()

setup(
    packages=find_packages(
        include=["flash_attn*"],
    ),
    ext_modules=[
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
            ],
            extra_compile_args={
                "cxx": [
                    "-std=c++17",
                ],
                "nvcc": [
                    "-std=c++17",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    # "--ptxas-options=-v",
                    # "--ptxas-options=-O2",
                    # "-lineinfo",
                    # "-DFLASHATTENTION_DISABLE_BACKWARD",
                    # "-DFLASHATTENTION_DISABLE_DROPOUT",
                    # "-DFLASHATTENTION_DISABLE_ALIBI",
                    # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                    # "-DFLASHATTENTION_DISABLE_LOCAL",
                ],
            },
            include_dirs=[
                _this_dir / "csrc" / "flash_attn",
                _this_dir / "csrc" / "flash_attn" / "src",
                _this_dir / "csrc" / "cutlass" / "include",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
)
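
Beyond the import-only check in the recipe's test section, here is a hedged sketch of how the built package would be exercised on a machine that does have a CUDA GPU, using flash-attn's public flash_attn_func (fp16/bf16 tensors with shape (batch, seqlen, nheads, headdim)):

    import torch
    from flash_attn import flash_attn_func

    # Hedged sketch, not part of the PR: a runtime smoke test on a CUDA machine.
    q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

    out = flash_attn_func(q, k, v, causal=True)
    print(out.shape)  # torch.Size([2, 128, 8, 64])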