Skip to content
63 changes: 29 additions & 34 deletions python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@
import sysconfig
from distutils.sysconfig import get_python_lib
from . import utils
from .utils.tools import flagtree_configs as configs

extend_backends = []
default_backends = ["nvidia", "amd"]
plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"]
ext_sourcedir = "triton/_C/"
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend", "sunrise": "sunrise"}
activated_module = utils.activate(flagtree_backend)
downloader = utils.tools.DownloadManager()
configs = configs
flagtree_backend = configs.flagtree_backend

set_llvm_env = lambda path: set_env(
{
Expand All @@ -37,7 +31,7 @@ def yield_backend_dirs():

def install_extension(*args, **kargs):
try:
activated_module.install_extension(*args, **kargs)
configs.activated_module.install_extension(*args, **kargs)
except Exception:
pass

Expand All @@ -50,7 +44,7 @@ def get_backend_cmake_args(*args, **kargs):
handle_plugin_backend(editable)
try:
# cmake_args = configs.activated_module.get_backend_cmake_args(*args, **kargs)
cmake_args = activated_module.get_backend_cmake_args(*args, **kargs)
cmake_args = configs.activated_module.get_backend_cmake_args(*args, **kargs)
except Exception:
cmake_args = []
if editable:
Expand All @@ -59,13 +53,13 @@ def get_backend_cmake_args(*args, **kargs):


def get_device_name():
return device_mapping[flagtree_backend]
return configs.device_alias_map[flagtree_backend]


def get_extra_packages():
packages = []
try:
packages = activated_module.get_extra_install_packages()
packages = configs.activated_module.get_extra_install_packages()
except Exception:
packages = []
return packages
Expand All @@ -74,7 +68,7 @@ def get_extra_packages():
def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
try:
package_data += activated_module.get_package_data_tools()
package_data += configs.activated_module.get_package_data_tools()
except Exception:
package_data
return package_data
Expand All @@ -100,16 +94,16 @@ def download_flagtree_third_party(name, condition, required=False, hock=None):
submodule = utils.flagtree_submodules[name]
downloader.download(module=submodule, required=required)
if callable(hock):
hock(third_party_base_dir=utils.flagtree_submodule_dir, backend=submodule,
default_backends=default_backends)
configs.default_backends = hock(third_party_base_dir=utils.flagtree_configs.flagtree_submodule_dir,
backend=submodule, default_backends=configs.default_backends)

else:
print(f"\033[1;33m[Note] Skip downloading {name} since USE_{name.upper()} is set to OFF\033[0m")


def post_install():
try:
activated_module.post_install()
configs.activated_module.post_install()
except Exception:
pass

Expand Down Expand Up @@ -270,45 +264,45 @@ def skip_package_dir(package):
if 'backends' in package or 'profiler' in package:
return True
try:
return activated_module.skip_package_dir(package)
return configs.activated_module.skip_package_dir(package)
except Exception:
return False

@staticmethod
def get_package_dir(packages):
package_dict = {}
if flagtree_backend and flagtree_backend not in plugin_backends:
if configs.flagtree_backend and configs.flagtree_backend not in configs.plugin_backends:
connection = []
backend_triton_path = f"./third_party/{flagtree_backend}/python/"
backend_triton_path = f"./third_party/{configs.flagtree_backend}/python/"
for package in packages:
if CommonUtils.skip_package_dir(package):
continue
pair = (package, f"{backend_triton_path}{package}")
connection.append(pair)
package_dict.update(connection)
try:
package_dict.update(activated_module.get_package_dir())
package_dict.update(configs.activated_module.get_package_dir())
except Exception:
pass
return package_dict


def handle_flagtree_backend():
global ext_sourcedir
if flagtree_backend:
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends:
ext_sourcedir = os.path.abspath(f"./third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
if configs.flagtree_backend:
print(f"\033[1;32m[INFO] FlagtreeBackend is {configs.flagtree_backend}\033[0m")
configs.extend_backends.append(configs.flagtree_backend)
if "editable_wheel" in sys.argv and configs.flagtree_backend not in configs.plugin_backends:
ext_sourcedir = os.path.abspath(f"./third_party/{configs.flagtree_backend}/python/{ext_sourcedir}") + "/"


def handle_plugin_backend(editable):
plugin_mode = os.getenv("FLAGTREE_PLUGIN")
if plugin_mode and plugin_mode.upper() not in ["0", "OFF"]:
if (plugin_mode and plugin_mode.upper() not in ["0", "OFF"]) or not configs.flagtree_backend:
return
flagtree_backend_dir = Path.home() / ".flagtree" / flagtree_backend
flagtree_plugin_so = flagtree_backend + "TritonPlugin.so"
if flagtree_backend in ["iluvatar", "mthreads", "sunrise"]:
flagtree_backend_dir = Path.home() / ".flagtree" / configs.flagtree_backend
flagtree_plugin_so = configs.flagtree_backend + "TritonPlugin.so"
if configs.flagtree_backend in ["iluvatar", "mthreads", "sunrise"]:
if editable is False:
src_build_plugin_path = flagtree_backend_dir / flagtree_plugin_so
dst_build_plugin_dir = Path(sysconfig.get_path("purelib")) / "triton" / "_C"
Expand Down Expand Up @@ -384,9 +378,10 @@ def uninstall_triton():
)

cache.store(
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not flagtree_plugin), url=
file="iluvatarTritonPlugin.so", condition=("iluvatar" == configs.flagtree_backend)
and (not configs.flagtree_plugin), url=
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64_v0.3.0.tar.gz",
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="015b9af8")
copy_dst_path=f"third_party/{configs.flagtree_backend}", md5_digest="015b9af8")

# klx xpu
cache.store(
Expand Down Expand Up @@ -423,7 +418,7 @@ def uninstall_triton():
)

cache.store(
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not flagtree_plugin), url=
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not configs.flagtree_plugin), url=
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/mthreadsTritonPlugin-cpython3.10-glibc2.35-glibcxx3.4.30-cxxabi1.3.13-ubuntu-x86_64_v0.3.0.tar.gz",
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="2a9ca0b8")

Expand Down Expand Up @@ -498,6 +493,6 @@ def sunrise_set_llvm_env(path):
)

cache.store(
file="sunriseTritonPlugin.so", condition=("sunrise" == flagtree_backend) and (not flagtree_plugin), url=
file="sunriseTritonPlugin.so", condition=("sunrise" == flagtree_backend) and (not configs.flagtree_plugin), url=
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/sunriseTritonPlugin-cpython3.10-glibc2.39-glibcxx3.4.33-x86_64_v0.4.0.tar.gz",
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="1f0b7e67")
6 changes: 3 additions & 3 deletions python/setup_tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import importlib.util
import os
from . import tools, default, aipu
from .tools import flagtree_submodule_dir, OfflineBuildManager
from .tools import flagtree_configs, OfflineBuildManager

flagtree_submodules = {
"triton_shared":
tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
commit_id="5842469a16b261e45a2c67fbfc308057622b03ee",
dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")),
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")),
"flir":
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
dst_path=os.path.join(flagtree_submodule_dir, "flir")),
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")),
}


Expand Down
5 changes: 4 additions & 1 deletion python/setup_tools/utils/aipu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
def precompile_hock(*args, **kargs):
default_backends = kargs["default_backends"]
default_backends.append('flir')
default_backends_list = [*default_backends, "flir"]
kargs["default_backends"] = tuple(default_backends_list)
default_backends = tuple(default_backends_list)
return default_backends
62 changes: 53 additions & 9 deletions python/setup_tools/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,54 @@

from python.build_helpers import get_base_dir
import platform
from typing import Mapping
from types import MappingProxyType
import importlib.util
from dataclasses import field

flagtree_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
flagtree_submodule_dir = os.path.join(flagtree_root_dir, "third_party")
flagtree_backend = os.environ.get("FLAGTREE_BACKEND")
use_cuda_toolkit = ["aipu"]

def _get_flagtree_root() -> str:
return str(Path(__file__).resolve().parents[3])


@dataclass
class FlagtreeConfigs:
default_backends: tuple = ("nvidia", "amd")
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame")
use_cuda_toolkit_backends: tuple = ('aipu', )
language_extra_backends: tuple = ('xpu', 'mthreads', "cambricon")
ext_sourcedir: str = "triton/_C/"
flagtree_root_dir: str = field(default_factory=_get_flagtree_root)
flagtree_backend: str = field(default_factory=lambda: os.environ.get("FLAGTREE_BACKEND"))
flagtree_plugin: str = field(default_factory=lambda: os.environ.get("FLAGTREE_PLUGIN"))
extend_backends: list = field(default_factory=list)
activated_module: any = None
flagtree_submodule_dir: str = ''
device_alias_map: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({
"xpu": "xpu",
"mthreads": "musa",
"ascend": "ascend",
"cambricon": "mlu",
}))

def __post_init__(self):
self.flagtree_submodule_dir = os.path.join(self.flagtree_root_dir, "third_party")
self.activated_module = self._activate_device_module()

def _activate_device_module(self, suffix=".py"):
backend = self.flagtree_backend or "default"
module_path = Path(os.path.dirname(__file__)) / backend
module_path = str(module_path) + suffix
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except (AttributeError, FileNotFoundError, ImportError, ModuleNotFoundError):
pass
return module


flagtree_configs = FlagtreeConfigs()


@dataclass
Expand All @@ -41,7 +84,8 @@ def dir_rollback(deep, base_path):


def is_skip_cuda_toolkits():
return flagtree_backend and (flagtree_backend not in use_cuda_toolkit)
return flagtree_configs.flagtree_backend and (flagtree_configs.flagtree_backend
not in flagtree_configs.use_cuda_toolkit_backends)


def remove_triton_in_modules(model):
Expand Down Expand Up @@ -216,7 +260,7 @@ def is_offline_build(self) -> bool:
return os.getenv("TRITON_OFFLINE_BUILD", "OFF") == "ON" or os.getenv("FLAGTREE_OFFLINE_BUILD_DIR")

def copy_to_flagtree_project(self, kargs):
dst_path = os.path.join(flagtree_root_dir,
dst_path = os.path.join(_get_flagtree_root(),
kargs['dst_path']) if 'dst_path' in kargs and kargs['dst_path'] else None
src_path = self.src
if not dst_path:
Expand Down Expand Up @@ -265,7 +309,7 @@ def handle_triton_origin_toolkits(self):
shutil.copytree(src_path, toolkit_cache_path, dirs_exist_ok=True)
else:
raise RuntimeError(
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
)

def validate_offline_build_dir(self, path, required=False):
Expand All @@ -280,7 +324,7 @@ def validate_offline_build_deps(self, path, kargs, required=False):
url = kargs.get('url', None)
if (not path or not os.path.exists(path)) and required:
raise RuntimeError(
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
f" And you can download the dependency package from the \n \033[93m{url}\033[0m \n"
f" then extract it to the \033[93m{self.offline_build_dir}\033[0m directory you specified !\033[0m\n\n")

Expand All @@ -301,7 +345,7 @@ def single_build(self, *args, **kargs):
self.copy_to_flagtree_project(kargs)
self.handle_flagtree_hock(kargs)
if is_skip_cuda_toolkits():
print(f"[INFO] Skipping CUDA toolkits for {flagtree_backend} backend in offline build.")
print(f"[INFO] Skipping CUDA toolkits for {flagtree_configs.flagtree_backend} backend in offline build.")
else:
self.handle_triton_origin_toolkits()
return True
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,13 +606,13 @@ def download_and_copy_dependencies():
if helper.flagtree_backend:
if helper.flagtree_backend in ("aipu", "tsingmicro"):
backends = [
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
*BackendInstaller.copy(helper.configs.default_backends + tuple(helper.configs.extend_backends)),
*BackendInstaller.copy_externals(),
]
else:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
backends = [*BackendInstaller.copy(helper.configs.extend_backends), *BackendInstaller.copy_externals()]
else:
print(helper.default_backends)
print(helper.configs.default_backends)
backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]

#backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
Expand Down
Loading