|
6 | 6 | import glob
|
7 | 7 | import os
|
8 | 8 | import subprocess
|
| 9 | +import sys |
| 10 | +import time |
9 | 11 | from datetime import datetime
|
10 | 12 |
|
11 | 13 | from setuptools import find_packages, setup
|
@@ -58,6 +60,71 @@ def read_version(file_path="version.txt"):
|
58 | 60 | CUDAExtension,
|
59 | 61 | )
|
60 | 62 |
|
| 63 | +# Constant known variables used throughout this file |
| 64 | +cwd = os.path.abspath(os.path.curdir) |
| 65 | +third_party_path = os.path.join(cwd, "third_party") |
| 66 | + |
| 67 | + |
| 68 | +def get_submodule_folders(): |
| 69 | + git_modules_path = os.path.join(cwd, ".gitmodules") |
| 70 | + default_modules_path = [ |
| 71 | + os.path.join(third_party_path, name) |
| 72 | + for name in [ |
| 73 | + "cutlass", |
| 74 | + ] |
| 75 | + ] |
| 76 | + if not os.path.exists(git_modules_path): |
| 77 | + return default_modules_path |
| 78 | + with open(git_modules_path) as f: |
| 79 | + return [ |
| 80 | + os.path.join(cwd, line.split("=", 1)[1].strip()) |
| 81 | + for line in f |
| 82 | + if line.strip().startswith("path") |
| 83 | + ] |
| 84 | + |
| 85 | + |
| 86 | +def check_submodules(): |
| 87 | + def check_for_files(folder, files): |
| 88 | + if not any(os.path.exists(os.path.join(folder, f)) for f in files): |
| 89 | + print("Could not find any of {} in {}".format(", ".join(files), folder)) |
| 90 | + print("Did you run 'git submodule update --init --recursive'?") |
| 91 | + sys.exit(1) |
| 92 | + |
| 93 | + def not_exists_or_empty(folder): |
| 94 | + return not os.path.exists(folder) or ( |
| 95 | + os.path.isdir(folder) and len(os.listdir(folder)) == 0 |
| 96 | + ) |
| 97 | + |
| 98 | + if bool(os.getenv("USE_SYSTEM_LIBS", False)): |
| 99 | + return |
| 100 | + folders = get_submodule_folders() |
| 101 | + # If none of the submodule folders exists, try to initialize them |
| 102 | + if all(not_exists_or_empty(folder) for folder in folders): |
| 103 | + try: |
| 104 | + print(" --- Trying to initialize submodules") |
| 105 | + start = time.time() |
| 106 | + subprocess.check_call( |
| 107 | + ["git", "submodule", "update", "--init", "--recursive"], cwd=cwd |
| 108 | + ) |
| 109 | + end = time.time() |
| 110 | + print(f" --- Submodule initialization took {end - start:.2f} sec") |
| 111 | + except Exception: |
| 112 | + print(" --- Submodule initalization failed") |
| 113 | + print("Please run:\n\tgit submodule update --init --recursive") |
| 114 | + sys.exit(1) |
| 115 | + for folder in folders: |
| 116 | + check_for_files( |
| 117 | + folder, |
| 118 | + [ |
| 119 | + "CMakeLists.txt", |
| 120 | + "Makefile", |
| 121 | + "setup.py", |
| 122 | + "LICENSE", |
| 123 | + "LICENSE.md", |
| 124 | + "LICENSE.txt", |
| 125 | + ], |
| 126 | + ) |
| 127 | + |
61 | 128 |
|
62 | 129 | def get_extensions():
|
63 | 130 | debug_mode = os.getenv("DEBUG", "0") == "1"
|
@@ -106,8 +173,7 @@ def get_extensions():
|
106 | 173 | use_cutlass = False
|
107 | 174 | if use_cuda and not IS_WINDOWS:
|
108 | 175 | use_cutlass = True
|
109 |
| - this_dir = os.path.abspath(os.path.curdir) |
110 |
| - cutlass_dir = os.path.join(this_dir, "third_party", "cutlass") |
| 176 | + cutlass_dir = os.path.join(third_party_path, "cutlass") |
111 | 177 | cutlass_include_dir = os.path.join(cutlass_dir, "include")
|
112 | 178 | if use_cutlass:
|
113 | 179 | extra_compile_args["nvcc"].extend(
|
@@ -145,6 +211,8 @@ def get_extensions():
|
145 | 211 | return ext_modules
|
146 | 212 |
|
147 | 213 |
|
| 214 | +check_submodules() |
| 215 | + |
148 | 216 | setup(
|
149 | 217 | name="torchao",
|
150 | 218 | version=version + version_suffix,
|
|
0 commit comments