Skip to content

Commit

Permalink
Allow third-party backends to add submodules to `triton.language.extr…
Browse files Browse the repository at this point in the history
…a` (triton-lang#4503)

Add an optional language directory to backends. The contents of the
directory is added to `triton.language.extra` when the wheel is built.
Update the existing `triton.language.extra.cuda` and
`triton.language.extra.hip` modules to use the new mechanism.

The core Triton is a small number of people, and we receive many PRs
(thank
you!).  To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
- [x] This PR does not need a test because `It is already tested by
python/test/unit/language/test_core.py::test_math_extern`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
  • Loading branch information
Alfie-Edwards authored and guacamoleo committed Nov 14, 2024
1 parent 19dfb35 commit 481059b
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 11 deletions.
60 changes: 53 additions & 7 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@
@dataclass
class Backend:
name: str
package_data: dict
package_data: list[str]
language_package_data: list[str]
src_dir: str
backend_dir: str
language_dir: str
install_dir: str
is_external: bool

Expand Down Expand Up @@ -62,12 +64,22 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend"))
assert os.path.exists(backend_path), f"{backend_path} does not exist!"

language_dir = os.path.abspath(os.path.join(backend_src_dir, "language"))
if not os.path.exists(language_dir):
language_dir = None

for file in ["compiler.py", "driver.py"]:
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"

install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)]
return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path,

language_package_data = []
if language_dir is not None:
language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)]

return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data,
src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
install_dir=install_dir, is_external=is_external)

# Copy all in-tree backends under triton/third_party.
Expand Down Expand Up @@ -556,6 +568,19 @@ def add_link_to_backends():
shutil.rmtree(backend.install_dir)
os.symlink(backend.backend_dir, backend.install_dir)

if backend.language_dir:
# Link the contents of each backend's `language` directory into
# `triton.language.extra`.
extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra"))
for x in os.listdir(backend.language_dir):
src_dir = os.path.join(backend.language_dir, x)
install_dir = os.path.join(extra_dir, x)
if os.path.islink(install_dir):
os.unlink(install_dir)
if os.path.exists(install_dir):
shutil.rmtree(install_dir)
os.symlink(src_dir, install_dir)


def add_link_to_proton():
proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton"))
Expand Down Expand Up @@ -602,28 +627,49 @@ def run(self):


package_data = {
"triton/tools": ["compile.h", "compile.c"],
**{f"triton/backends/{b.name}": b.package_data
for b in backends},
"triton/tools": ["compile.h", "compile.c"], **{f"triton/backends/{b.name}": b.package_data
for b in backends}, "triton/language/extra": sum(
(b.language_package_data for b in backends), [])
}


def get_language_extra_packages():
packages = []
for backend in backends:
if backend.language_dir is None:
continue

# Walk the `language` directory of each backend to enumerate
# any subpackages, which will be added to `triton.language.extra`.
for dir, dirs, files in os.walk(backend.language_dir, followlinks=True):
if not any(f for f in files if f.endswith(".py")) or dir == backend.language_dir:
# Ignore directories with no python files.
# Also ignore the root directory which corresponds to
# "triton/language/extra".
continue
subpackage = os.path.relpath(dir, backend.language_dir)
package = os.path.join("triton/language/extra", subpackage)
packages.append(package)

return list(packages)


def get_packages():
packages = [
"triton",
"triton/_C",
"triton/compiler",
"triton/language",
"triton/language/extra",
"triton/language/extra/cuda",
"triton/language/extra/hip",
"triton/runtime",
"triton/backends",
"triton/tools",
]
packages += [f'triton/backends/{backend.name}' for backend in backends]
packages += get_language_extra_packages()
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
packages += ["triton/profiler"]

return packages


Expand Down
4 changes: 0 additions & 4 deletions python/triton/language/extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
from . import cuda
from . import hip

__all__ = ['cuda', 'hip']
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 481059b

Please sign in to comment.