Skip to content

Commit 5e182db

Browse files
[Backend] Add metal backend (tile-ai#799)
* Reset * Fix other CUDA issue * fmt * fmt * fix cuda error * fix * fix * fmt * cleanup * fix * remove copyright * trivial update * readme update * lint fix --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 009be27 commit 5e182db

File tree

29 files changed

+575
-69
lines changed

29 files changed

+575
-69
lines changed

.github/workflows/metal_ci.yml

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
name: CI Test on Metal
2+
on: [pull_request]
3+
4+
env:
5+
PYTHON_VERSION: '3.12'
6+
VENV_DIR: tilelang_ci
7+
8+
jobs:
9+
format-check:
10+
runs-on: [macos-latest]
11+
12+
permissions:
13+
contents: write
14+
15+
steps:
16+
- name: Checkout repository
17+
uses: actions/checkout@v4
18+
with:
19+
fetch-depth: 0
20+
submodules: recursive
21+
22+
- name: Install python via uv
23+
uses: astral-sh/setup-uv@v6
24+
with:
25+
enable-cache: true
26+
ignore-nothing-to-cache: true
27+
activate-environment: true
28+
python-version: ${{ env.PYTHON_VERSION }}
29+
30+
- name: Ensure venv (local & persistent)
31+
run: |
32+
[[ -f requirements-test.txt ]] && \
33+
uv pip install -r requirements-test.txt --no-build-isolation
34+
35+
- name: Run format check
36+
run: |
37+
set -ex
38+
mkdir -p build
39+
# run cmake to create the build directory with compile_commands.json
40+
cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_METAL=ON; cd ..
41+
if ! output=$(./format.sh 2>&1); then
42+
echo "------------------------------------"
43+
echo "message:"
44+
echo "$output"
45+
printf '%s\n' "$output"
46+
echo "------------------------------------"
47+
exit 1
48+
fi
49+
50+
build-test-metal:
51+
runs-on: [macos-latest]
52+
needs: format-check
53+
permissions:
54+
contents: read
55+
env:
56+
CMAKE_C_COMPILER_LAUNCHER: ccache
57+
CMAKE_CXX_COMPILER_LAUNCHER: ccache
58+
steps:
59+
- name: Checkout repository
60+
uses: actions/checkout@v4
61+
with:
62+
fetch-depth: 1
63+
submodules: recursive
64+
65+
- name: ccache
66+
uses: hendrikmuhs/ccache-action@v1.2
67+
with:
68+
create-symlink: true
69+
key: ${{ github.job }}-${{ matrix.os }}
70+
71+
- name: Install python via uv
72+
uses: astral-sh/setup-uv@v6
73+
with:
74+
enable-cache: true
75+
ignore-nothing-to-cache: true
76+
activate-environment: true
77+
python-version: ${{ env.PYTHON_VERSION }}
78+
79+
- name: Ensure venv (local & persistent)
80+
run: uv pip install -r requirements-test.txt -r requirements-build.txt
81+
82+
- name: Build wheel
83+
run: |
84+
source .venv/bin/activate
85+
uv pip install -v --no-build-isolation .
86+
87+
- name: Run metal test
88+
run: |
89+
cd testing/python
90+
unset PYTHONPATH
91+
python -m pytest -k metal -v -r fE --durations=0 --timeout=3600

CMakeLists.txt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,21 @@ endif()
108108
if(DEFINED TVM_PREBUILD_PATH)
109109
message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
110110
add_library(tvm SHARED IMPORTED)
111+
find_library(TVM_LIBRARY_LOCATION
112+
NAMES tvm
113+
HINTS "${TVM_PREBUILD_PATH}"
114+
)
111115
set_target_properties(tvm PROPERTIES
112-
IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm.so"
116+
IMPORTED_LOCATION "${TVM_LIBRARY_LOCATION}"
113117
INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
114118
)
115119
add_library(tvm_runtime SHARED IMPORTED)
120+
find_library(TVM_RUNTIME_LIBRARY_LOCATION
121+
NAMES tvm_runtime
122+
HINTS "${TVM_PREBUILD_PATH}"
123+
)
116124
set_target_properties(tvm_runtime PROPERTIES
117-
IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm_runtime.so"
125+
IMPORTED_LOCATION "${TVM_RUNTIME_LIBRARY_LOCATION}"
118126
INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
119127
)
120128
else()
@@ -157,6 +165,13 @@ if(USE_ROCM)
157165
list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS})
158166
endif()
159167

168+
if(USE_METAL)
169+
tilelang_file_glob(GLOB TILE_LANG_METAL_SRCS
170+
src/target/rt_mod_metal.cc
171+
)
172+
list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
173+
endif()
174+
160175
message(STATUS "Collected source files: ${TILE_LANG_SRCS}")
161176

162177
# Add TileLang object library
@@ -221,6 +236,9 @@ target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS)
221236
# Shared library
222237
add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
223238
target_link_libraries(tilelang PUBLIC tvm_runtime)
239+
if(USE_METAL)
240+
target_link_libraries(tilelang PUBLIC tvm)
241+
endif()
224242

225243
# Static library
226244
add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
1313
<img src=./images/MatmulExample.png />
1414

1515
## Latest News
16+
- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
1617
- 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported!
1718
Check out the preview here:
1819
🔗 [link](https://github.com/tile-ai/tilelang-ascend).

install_metal.sh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
2+
3+
set -eux
4+
5+
git submodule update --init --recursive
6+
7+
rm -rf build
8+
9+
mkdir build
10+
cp 3rdparty/tvm/cmake/config.cmake build
11+
cd build
12+
13+
echo "set(USE_METAL ON)" >> config.cmake
14+
15+
CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache cmake ..
16+
17+
CORES=$(sysctl -n hw.logicalcpu)
18+
MAKE_JOBS=$(( CORES / 2 ))
19+
make -j${MAKE_JOBS}

setup.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,60 @@
3232

3333
logger = logging.getLogger(__name__)
3434

35+
36+
def _read_bool_env(name: str, default: bool = False) -> bool:
37+
if env := os.environ.get(name):
38+
env = env.lower()
39+
if env in ['on', '1', 'true']:
40+
return True
41+
elif env in ['', 'off', '0', 'false']:
42+
return False
43+
return default
44+
45+
3546
# Environment variables False/True
36-
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
47+
PYPI_BUILD = _read_bool_env('PYPI_BUILD')
3748
PACKAGE_NAME = "tilelang"
3849
ROOT_DIR = os.path.dirname(__file__)
3950

51+
CYCACHE = Path(os.path.join(ROOT_DIR, "tilelang", "jit", "adapter", "cython", ".cycache"))
52+
if not CYCACHE.exists():
53+
# tvm may needs this, we won't always build cython backend so mkdir here.
54+
CYCACHE.mkdir(exist_ok=True)
55+
56+
IS_LINUX = platform.system() == 'Linux'
57+
MAYBE_METAL = platform.mac_ver()[2] == 'arm64'
58+
4059
# Add LLVM control environment variable
41-
USE_LLVM = os.environ.get("USE_LLVM", "False").lower() == "true"
60+
USE_LLVM = _read_bool_env('USE_LLVM')
61+
# Add ROCM control environment variable
62+
USE_ROCM = _read_bool_env("USE_ROCM")
4263
# Add ROCM control environment variable
43-
USE_ROCM = os.environ.get("USE_ROCM", "False").lower() == "true"
64+
USE_METAL = _read_bool_env("USE_METAL", MAYBE_METAL)
65+
# Add ROCM control environment variable
66+
USE_CUDA = _read_bool_env("USE_CUDA", IS_LINUX and not USE_ROCM)
4467
# Build with Debug mode
45-
DEBUG_MODE = os.environ.get("DEBUG_MODE", "False").lower() == "true"
68+
DEBUG_MODE = _read_bool_env('DEBUG_MODE')
4669
# Include commit ID in wheel filename and package metadata
47-
WITH_COMMITID = os.environ.get("WITH_COMMITID", "True").lower() == "true"
70+
WITH_COMMITID = _read_bool_env("WITH_COMMITID")
71+
72+
TVM_PREBUILD_ITEMS = [
73+
"libtvm_runtime.so",
74+
"libtvm.so",
75+
"libtilelang.so",
76+
"libtilelang_module.so",
77+
] if IS_LINUX else [
78+
"libtvm_runtime.dylib",
79+
"libtvm.dylib",
80+
"libtilelang.dylib",
81+
"libtilelang_module.dylib",
82+
]
83+
84+
# from tvm's internal cython?
85+
TVM_PREBUILD_ITEMS_TO_DELETE = [] if IS_LINUX else [
86+
'libtvm_runtime.dylib.dSYM',
87+
'libtvm.dylib.dSYM',
88+
]
4889

4990

5091
def load_module_from_path(module_name, path):
@@ -65,24 +106,17 @@ def load_module_from_path(module_name, path):
65106
raise ValueError(
66107
"ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.")
67108

68-
if not USE_ROCM and not CUDA_HOME:
109+
if USE_CUDA and not CUDA_HOME:
69110
raise ValueError(
70-
"CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected.")
111+
"CUDA support is enabled by default on linux if `USE_ROCM=False`," \
112+
" but CUDA_HOME is not set or detected.")
71113

72114
# Ensure one of CUDA or ROCM is available
73-
if not (CUDA_HOME or ROCM_HOME):
115+
if IS_LINUX and not (CUDA_HOME or ROCM_HOME):
74116
raise ValueError(
75117
"Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)."
76118
)
77119

78-
# TileLang only supports Linux platform
79-
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
80-
81-
82-
def _is_linux_like():
83-
return (sys.platform == "darwin" or sys.platform.startswith("linux") or
84-
sys.platform.startswith("freebsd"))
85-
86120

87121
def get_path(*filepath) -> str:
88122
return os.path.join(ROOT_DIR, *filepath)
@@ -144,7 +178,9 @@ def get_rocm_version():
144178
return Version("5.0.0")
145179

146180

147-
def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
181+
def get_tilelang_version(with_cuda=USE_CUDA,
182+
with_system_info=not MAYBE_METAL,
183+
with_commit_id=False) -> str:
148184
version = find_version(get_path(".", "VERSION"))
149185
local_version_parts = []
150186
if with_system_info:
@@ -194,9 +230,6 @@ def get_cplus_compiler():
194230
The path to the default C/C++ compiler, or None if none was found.
195231
"""
196232

197-
if not _is_linux_like():
198-
return None
199-
200233
env_cxx = os.environ.get("CXX") or os.environ.get("CC")
201234
if env_cxx:
202235
return env_cxx
@@ -371,6 +404,8 @@ def patch_libs(libpath):
371404
and have a hard-coded rpath.
372405
Set rpath to the directory of libs so auditwheel works well.
373406
"""
407+
if not IS_LINUX:
408+
return
374409
# check if patchelf is installed
375410
# find patchelf in the system
376411
patchelf_path = shutil.which("patchelf")
@@ -432,13 +467,6 @@ def run(self):
432467
os.makedirs(target_dir)
433468
shutil.copy2(source_dir, target_dir)
434469

435-
TVM_PREBUILD_ITEMS = [
436-
"libtvm_runtime.so",
437-
"libtvm.so",
438-
"libtilelang.so",
439-
"libtilelang_module.so",
440-
]
441-
442470
potential_dirs = [
443471
ext_output_dir,
444472
self.build_lib,
@@ -468,6 +496,14 @@ def run(self):
468496
else:
469497
logger.info(f"WARNING: {item} not found in any expected directories!")
470498

499+
for item in TVM_PREBUILD_ITEMS_TO_DELETE:
500+
source_lib_file = None
501+
for dir in potential_dirs:
502+
candidate = os.path.join(dir, item)
503+
if os.path.exists(candidate):
504+
shutil.rmtree(candidate)
505+
break
506+
471507
TVM_CONFIG_ITEMS = [
472508
f"{build_temp_dir}/config.cmake",
473509
]
@@ -587,10 +623,10 @@ class CMakeExtension(Extension):
587623
:param sourcedir: Directory containing the top-level CMakeLists.txt.
588624
"""
589625

590-
def __init__(self, name, sourcedir=""):
626+
def __init__(self, name, sourcedir="", **kwargs):
591627
# We pass an empty 'sources' list because
592628
# the actual build is handled by CMake, not setuptools.
593-
super().__init__(name=name, sources=[])
629+
super().__init__(name=name, sources=[], **kwargs)
594630

595631
# Convert the source directory to an absolute path
596632
# so that CMake can correctly locate the CMakeLists.txt.
@@ -642,7 +678,7 @@ def run(self):
642678
# To make it works with editable install,
643679
# we need to copy the lib*.so files to the tilelang/lib directory
644680
import glob
645-
files = glob.glob("*.so")
681+
files = glob.glob("*.so" if IS_LINUX else "*.dylib")
646682
if os.path.exists(PACKAGE_NAME):
647683
target_lib_dir = os.path.join(PACKAGE_NAME, "lib")
648684
for file in files:
@@ -724,7 +760,10 @@ def build_cython(self, ext):
724760
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
725761
python_include_path = sysconfig.get_path("include")
726762
cc = get_cplus_compiler()
763+
if MAYBE_METAL:
764+
cc += ' -Wl,-undefined,dynamic_lookup'
727765
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
766+
logger.info(command)
728767
os.system(command)
729768

730769
# rename the temp file to the library file
@@ -783,7 +822,7 @@ def build_cmake(self, ext):
783822
"-G",
784823
"Ninja",
785824
]
786-
if not USE_ROCM:
825+
if USE_CUDA and not USE_ROCM:
787826
cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}")
788827

789828
# Create the temporary build directory (if it doesn't exist).
@@ -804,12 +843,17 @@ def build_cmake(self, ext):
804843
content_lines.append(f"set(USE_LLVM {llvm_config_path})")
805844

806845
# Append GPU backend configuration based on environment
807-
if USE_ROCM:
846+
if USE_METAL:
847+
content_lines += [
848+
"set(USE_METAL ON)",
849+
"set(USE_ROCM OFF)",
850+
]
851+
elif USE_ROCM:
808852
content_lines += [
809853
f"set(USE_ROCM {ROCM_HOME})",
810854
"set(USE_CUDA OFF)",
811855
]
812-
else:
856+
elif CUDA_HOME:
813857
content_lines += [
814858
f"set(USE_CUDA {CUDA_HOME})",
815859
"set(USE_ROCM OFF)",
@@ -846,6 +890,12 @@ def build_cmake(self, ext):
846890
cwd=build_temp)
847891

848892

893+
ext_modules = [
894+
CMakeExtension("TileLangCXX", sourcedir="."),
895+
]
896+
if not MAYBE_METAL:
897+
ext_modules.append(CythonExtension("TileLangCython", sourcedir="."))
898+
849899
setup(
850900
name=PACKAGE_NAME,
851901
version=(get_tilelang_version(with_cuda=False, with_system_info=False, with_commit_id=False)

0 commit comments

Comments
 (0)