Raise early errors if submodules not available (facebookresearch#370)
fmassa authored Aug 26, 2022
1 parent 24bd176 commit bfbd373
Showing 1 changed file with 19 additions and 8 deletions.

setup.py
@@ -64,11 +64,14 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     DEFAULT_ARCHS_LIST = ""
     if cuda_version > 1100:
         DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
-    elif cuda_version >= 1100:
+    elif cuda_version == 1100:
         DEFAULT_ARCHS_LIST = "7.5;8.0"
     else:
         return []
 
+    if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
+        return []
+
     archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
     nvcc_archs_flags = []
     for arch in archs_list.split(";"):
@@ -86,6 +89,12 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
 
     this_dir = os.path.dirname(os.path.abspath(__file__))
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    if not os.path.exists(flash_root):
+        raise RuntimeError(
+            "flashattention submodule not found. Did you forget "
+            "to run `git submodule update --init --recursive` ?"
+        )
+
     return [
         CUDAExtension(
             name="xformers._C_flashattention",
@@ -144,6 +153,11 @@ def get_extensions():
 
     sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
     cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
+    if not os.path.exists(cutlass_dir):
+        raise RuntimeError(
+            "CUTLASS submodule not found. Did you forget "
+            "to run `git submodule update --init --recursive` ?"
+        )
+
     extension = CppExtension

@@ -174,13 +188,10 @@ def get_extensions():
         if cuda_version >= 1102:
             nvcc_flags += ["--threads", "4", "--ptxas-options=-v"]
         extra_compile_args["nvcc"] = nvcc_flags
-        if (
-            cuda_version >= 1100
-            and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0"
-        ):
-            ext_modules += get_flash_attention_extensions(
-                cuda_version=cuda_version, extra_compile_args=extra_compile_args
-            )
+
+        ext_modules += get_flash_attention_extensions(
+            cuda_version=cuda_version, extra_compile_args=extra_compile_args
+        )
 
     sources = [os.path.join(extensions_dir, s) for s in sources]
 
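Taken together, the commit fails fast when either the flash-attention or CUTLASS submodule checkout is missing, and moves the CUDA-version and XFORMERS_DISABLE_FLASH_ATTN gating into get_flash_attention_extensions itself, which now returns an empty list when flash-attention should not be built, so the call site in get_extensions is unconditional. Below is a minimal standalone sketch of the fail-fast pattern; the require_submodule helper is illustrative and not part of the commit:

import os

def require_submodule(path: str, name: str) -> None:
    # Fail fast with an actionable message instead of a confusing
    # compiler error much later in the build.
    if not os.path.exists(path):
        raise RuntimeError(
            f"{name} submodule not found. Did you forget "
            "to run `git submodule update --init --recursive` ?"
        )

this_dir = os.path.dirname(os.path.abspath(__file__))
require_submodule(
    os.path.join(this_dir, "third_party", "flash-attention"), "flashattention"
)
require_submodule(
    os.path.join(this_dir, "third_party", "cutlass", "include"), "CUTLASS"
)

As the error message suggests, a missing checkout is fixed by running `git submodule update --init --recursive` from the repository root.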
