Skip to content

Commit b7ca9b5

Browse files
committed
Verify that submodules are checked out
1 parent b5b739b commit b7ca9b5

File tree

1 file changed

+70
-2
lines changed

1 file changed

+70
-2
lines changed

setup.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import glob
77
import os
88
import subprocess
9+
import sys
10+
import time
911
from datetime import datetime
1012

1113
from setuptools import find_packages, setup
@@ -58,6 +60,71 @@ def read_version(file_path="version.txt"):
5860
CUDAExtension,
5961
)
6062

63+
# Constant known variables used throughout this file
64+
cwd = os.path.abspath(os.path.curdir)
65+
third_party_path = os.path.join(cwd, "third_party")
66+
67+
68+
def get_submodule_folders():
69+
git_modules_path = os.path.join(cwd, ".gitmodules")
70+
default_modules_path = [
71+
os.path.join(third_party_path, name)
72+
for name in [
73+
"cutlass",
74+
]
75+
]
76+
if not os.path.exists(git_modules_path):
77+
return default_modules_path
78+
with open(git_modules_path) as f:
79+
return [
80+
os.path.join(cwd, line.split("=", 1)[1].strip())
81+
for line in f
82+
if line.strip().startswith("path")
83+
]
84+
85+
86+
def check_submodules():
87+
def check_for_files(folder, files):
88+
if not any(os.path.exists(os.path.join(folder, f)) for f in files):
89+
print("Could not find any of {} in {}".format(", ".join(files), folder))
90+
print("Did you run 'git submodule update --init --recursive'?")
91+
sys.exit(1)
92+
93+
def not_exists_or_empty(folder):
94+
return not os.path.exists(folder) or (
95+
os.path.isdir(folder) and len(os.listdir(folder)) == 0
96+
)
97+
98+
if bool(os.getenv("USE_SYSTEM_LIBS", False)):
99+
return
100+
folders = get_submodule_folders()
101+
# If none of the submodule folders exists, try to initialize them
102+
if all(not_exists_or_empty(folder) for folder in folders):
103+
try:
104+
print(" --- Trying to initialize submodules")
105+
start = time.time()
106+
subprocess.check_call(
107+
["git", "submodule", "update", "--init", "--recursive"], cwd=cwd
108+
)
109+
end = time.time()
110+
print(f" --- Submodule initialization took {end - start:.2f} sec")
111+
except Exception:
112+
print(" --- Submodule initalization failed")
113+
print("Please run:\n\tgit submodule update --init --recursive")
114+
sys.exit(1)
115+
for folder in folders:
116+
check_for_files(
117+
folder,
118+
[
119+
"CMakeLists.txt",
120+
"Makefile",
121+
"setup.py",
122+
"LICENSE",
123+
"LICENSE.md",
124+
"LICENSE.txt",
125+
],
126+
)
127+
61128

62129
def get_extensions():
63130
debug_mode = os.getenv("DEBUG", "0") == "1"
@@ -106,8 +173,7 @@ def get_extensions():
106173
use_cutlass = False
107174
if use_cuda and not IS_WINDOWS:
108175
use_cutlass = True
109-
this_dir = os.path.abspath(os.path.curdir)
110-
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
176+
cutlass_dir = os.path.join(third_party_path, "cutlass")
111177
cutlass_include_dir = os.path.join(cutlass_dir, "include")
112178
if use_cutlass:
113179
extra_compile_args["nvcc"].extend(
@@ -145,6 +211,8 @@ def get_extensions():
145211
return ext_modules
146212

147213

214+
check_submodules()
215+
148216
setup(
149217
name="torchao",
150218
version=version + version_suffix,

0 commit comments

Comments
 (0)