3232
3333logger = logging .getLogger (__name__ )
3434
35+
36+ def _read_bool_env (name : str , default : bool = False ) -> bool :
37+ if env := os .environ .get (name ):
38+ env = env .lower ()
39+ if env in ['on' , '1' , 'true' ]:
40+ return True
41+ elif env in ['' , 'off' , '0' , 'false' ]:
42+ return False
43+ return default
44+
45+
3546# Environment variables False/True
36- PYPI_BUILD = os . environ . get ( " PYPI_BUILD" , "False" ). lower () == "true"
47+ PYPI_BUILD = _read_bool_env ( ' PYPI_BUILD' )
3748PACKAGE_NAME = "tilelang"
3849ROOT_DIR = os .path .dirname (__file__ )
3950
51+ CYCACHE = Path (os .path .join (ROOT_DIR , "tilelang" , "jit" , "adapter" , "cython" , ".cycache" ))
52+ if not CYCACHE .exists ():
53+ # tvm may needs this, we won't always build cython backend so mkdir here.
54+ CYCACHE .mkdir (exist_ok = True )
55+
56+ IS_LINUX = platform .system () == 'Linux'
57+ MAYBE_METAL = platform .mac_ver ()[2 ] == 'arm64'
58+
4059# Add LLVM control environment variable
41- USE_LLVM = os .environ .get ("USE_LLVM" , "False" ).lower () == "true"
60+ USE_LLVM = _read_bool_env ('USE_LLVM' )
61+ # Add ROCM control environment variable
62+ USE_ROCM = _read_bool_env ("USE_ROCM" )
4263# Add ROCM control environment variable
43- USE_ROCM = os .environ .get ("USE_ROCM" , "False" ).lower () == "true"
64+ USE_METAL = _read_bool_env ("USE_METAL" , MAYBE_METAL )
65+ # Add ROCM control environment variable
66+ USE_CUDA = _read_bool_env ("USE_CUDA" , IS_LINUX and not USE_ROCM )
4467# Build with Debug mode
45- DEBUG_MODE = os . environ . get ( " DEBUG_MODE" , "False" ). lower () == "true"
68+ DEBUG_MODE = _read_bool_env ( ' DEBUG_MODE' )
4669# Include commit ID in wheel filename and package metadata
47- WITH_COMMITID = os .environ .get ("WITH_COMMITID" , "True" ).lower () == "true"
70+ WITH_COMMITID = _read_bool_env ("WITH_COMMITID" )
71+
72+ TVM_PREBUILD_ITEMS = [
73+ "libtvm_runtime.so" ,
74+ "libtvm.so" ,
75+ "libtilelang.so" ,
76+ "libtilelang_module.so" ,
77+ ] if IS_LINUX else [
78+ "libtvm_runtime.dylib" ,
79+ "libtvm.dylib" ,
80+ "libtilelang.dylib" ,
81+ "libtilelang_module.dylib" ,
82+ ]
83+
84+ # from tvm's internal cython?
85+ TVM_PREBUILD_ITEMS_TO_DELETE = [] if IS_LINUX else [
86+ 'libtvm_runtime.dylib.dSYM' ,
87+ 'libtvm.dylib.dSYM' ,
88+ ]
4889
4990
5091def load_module_from_path (module_name , path ):
@@ -65,24 +106,17 @@ def load_module_from_path(module_name, path):
65106 raise ValueError (
66107 "ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected." )
67108
68- if not USE_ROCM and not CUDA_HOME :
109+ if USE_CUDA and not CUDA_HOME :
69110 raise ValueError (
70- "CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected." )
111+ "CUDA support is enabled by default on linux if `USE_ROCM=False`," \
112+ " but CUDA_HOME is not set or detected." )
71113
72114# Ensure one of CUDA or ROCM is available
73- if not (CUDA_HOME or ROCM_HOME ):
115+ if IS_LINUX and not (CUDA_HOME or ROCM_HOME ):
74116 raise ValueError (
75117 "Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)."
76118 )
77119
78- # TileLang only supports Linux platform
79- assert sys .platform .startswith ("linux" ), "TileLang only supports Linux platform (including WSL)."
80-
81-
82- def _is_linux_like ():
83- return (sys .platform == "darwin" or sys .platform .startswith ("linux" ) or
84- sys .platform .startswith ("freebsd" ))
85-
86120
87121def get_path (* filepath ) -> str :
88122 return os .path .join (ROOT_DIR , * filepath )
@@ -144,7 +178,9 @@ def get_rocm_version():
144178 return Version ("5.0.0" )
145179
146180
147- def get_tilelang_version (with_cuda = True , with_system_info = True , with_commit_id = False ) -> str :
181+ def get_tilelang_version (with_cuda = USE_CUDA ,
182+ with_system_info = not MAYBE_METAL ,
183+ with_commit_id = False ) -> str :
148184 version = find_version (get_path ("." , "VERSION" ))
149185 local_version_parts = []
150186 if with_system_info :
@@ -194,9 +230,6 @@ def get_cplus_compiler():
194230 The path to the default C/C++ compiler, or None if none was found.
195231 """
196232
197- if not _is_linux_like ():
198- return None
199-
200233 env_cxx = os .environ .get ("CXX" ) or os .environ .get ("CC" )
201234 if env_cxx :
202235 return env_cxx
@@ -371,6 +404,8 @@ def patch_libs(libpath):
371404 and have a hard-coded rpath.
372405 Set rpath to the directory of libs so auditwheel works well.
373406 """
407+ if not IS_LINUX :
408+ return
374409 # check if patchelf is installed
375410 # find patchelf in the system
376411 patchelf_path = shutil .which ("patchelf" )
@@ -432,13 +467,6 @@ def run(self):
432467 os .makedirs (target_dir )
433468 shutil .copy2 (source_dir , target_dir )
434469
435- TVM_PREBUILD_ITEMS = [
436- "libtvm_runtime.so" ,
437- "libtvm.so" ,
438- "libtilelang.so" ,
439- "libtilelang_module.so" ,
440- ]
441-
442470 potential_dirs = [
443471 ext_output_dir ,
444472 self .build_lib ,
@@ -468,6 +496,14 @@ def run(self):
468496 else :
469497 logger .info (f"WARNING: { item } not found in any expected directories!" )
470498
499+ for item in TVM_PREBUILD_ITEMS_TO_DELETE :
500+ source_lib_file = None
501+ for dir in potential_dirs :
502+ candidate = os .path .join (dir , item )
503+ if os .path .exists (candidate ):
504+ shutil .rmtree (candidate )
505+ break
506+
471507 TVM_CONFIG_ITEMS = [
472508 f"{ build_temp_dir } /config.cmake" ,
473509 ]
@@ -587,10 +623,10 @@ class CMakeExtension(Extension):
587623 :param sourcedir: Directory containing the top-level CMakeLists.txt.
588624 """
589625
590- def __init__ (self , name , sourcedir = "" ):
626+ def __init__ (self , name , sourcedir = "" , ** kwargs ):
591627 # We pass an empty 'sources' list because
592628 # the actual build is handled by CMake, not setuptools.
593- super ().__init__ (name = name , sources = [])
629+ super ().__init__ (name = name , sources = [], ** kwargs )
594630
595631 # Convert the source directory to an absolute path
596632 # so that CMake can correctly locate the CMakeLists.txt.
@@ -642,7 +678,7 @@ def run(self):
642678 # To make it works with editable install,
643679 # we need to copy the lib*.so files to the tilelang/lib directory
644680 import glob
645- files = glob .glob ("*.so" )
681+ files = glob .glob ("*.so" if IS_LINUX else "*.dylib" )
646682 if os .path .exists (PACKAGE_NAME ):
647683 target_lib_dir = os .path .join (PACKAGE_NAME , "lib" )
648684 for file in files :
@@ -724,7 +760,10 @@ def build_cython(self, ext):
724760 os .system (f"{ cython } { cython_wrapper_path } --cplus -o { source_path } " )
725761 python_include_path = sysconfig .get_path ("include" )
726762 cc = get_cplus_compiler ()
763+ if MAYBE_METAL :
764+ cc += ' -Wl,-undefined,dynamic_lookup'
727765 command = f"{ cc } -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{ python_include_path } { source_path } -o { temp_path } "
766+ logger .info (command )
728767 os .system (command )
729768
730769 # rename the temp file to the library file
@@ -783,7 +822,7 @@ def build_cmake(self, ext):
783822 "-G" ,
784823 "Ninja" ,
785824 ]
786- if not USE_ROCM :
825+ if USE_CUDA and not USE_ROCM :
787826 cmake_args .append (f"-DCMAKE_CUDA_COMPILER={ os .path .join (CUDA_HOME , 'bin' , 'nvcc' )} " )
788827
789828 # Create the temporary build directory (if it doesn't exist).
@@ -804,12 +843,17 @@ def build_cmake(self, ext):
804843 content_lines .append (f"set(USE_LLVM { llvm_config_path } )" )
805844
806845 # Append GPU backend configuration based on environment
807- if USE_ROCM :
846+ if USE_METAL :
847+ content_lines += [
848+ "set(USE_METAL ON)" ,
849+ "set(USE_ROCM OFF)" ,
850+ ]
851+ elif USE_ROCM :
808852 content_lines += [
809853 f"set(USE_ROCM { ROCM_HOME } )" ,
810854 "set(USE_CUDA OFF)" ,
811855 ]
812- else :
856+ elif CUDA_HOME :
813857 content_lines += [
814858 f"set(USE_CUDA { CUDA_HOME } )" ,
815859 "set(USE_ROCM OFF)" ,
@@ -846,6 +890,12 @@ def build_cmake(self, ext):
846890 cwd = build_temp )
847891
848892
893+ ext_modules = [
894+ CMakeExtension ("TileLangCXX" , sourcedir = "." ),
895+ ]
896+ if not MAYBE_METAL :
897+ ext_modules .append (CythonExtension ("TileLangCython" , sourcedir = "." ))
898+
849899setup (
850900 name = PACKAGE_NAME ,
851901 version = (get_tilelang_version (with_cuda = False , with_system_info = False , with_commit_id = False )
0 commit comments