[Frontend][MLIR] Fix crash when MLIR_DUMP_PATH is set to a directory #6549

Open
wants to merge 1 commit into
base: main
21 changes: 20 additions & 1 deletion README.md
@@ -142,6 +142,22 @@ arbitrary LLVM version.
environment variable (to the `pip install -e .` command) to limit the
number of jobs.

- If you are modifying the Triton C++ source files, e.g., under `python/src/`,
and encounter runtime errors or failed assertions, such as when generating
MLIR IR dumps, you may need to **fully clean and rebuild** Triton. This ensures
Python uses the updated shared object files (the `*.so` artifacts containing
your recent changes) rather than stale ones.

Run the following from the root of your Triton repository to trigger a clean
rebuild:

```bash
python setup.py clean --all
rm -rf build
find . -iname '*.so' -delete
pip install -e .
```

- Pass `--no-build-isolation` to `pip install` to make nop builds faster.
Without this, every invocation of `pip install` uses a different symlink to
cmake, and this forces ninja to rebuild most of the `.a` files.
@@ -188,7 +204,10 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs, for all
kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only.
- Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try cleaning your triton cache: `rm -r ~/.triton/cache/*`
- `MLIR_DUMP_PATH` specifies where `MLIR_ENABLE_DUMP` will dump to. If unset will dump to stderr.
- `MLIR_DUMP_PATH` specifies the destination for MLIR IR dumps created via `MLIR_ENABLE_DUMP`.
- If the value is a **filename**, Triton writes a single `.mlir` file.
- If the value is a **directory**, Triton will write a timestamped `.mlir` file for each kernel.
- If unset, dumps are written to stderr.
- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR.
- `TRITON_REPRODUCER_PATH=<reproducer_path>` will generate an MLIR reproducer file
at `<reproducer_path>` before each MLIR compiler stage. If any of the stages fail,
43 changes: 38 additions & 5 deletions python/src/ir.cc
@@ -1,4 +1,7 @@
#include <cstdlib>
#include <memory>
#include <optional>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
@@ -45,14 +48,44 @@ namespace tt = triton;
namespace ttg = triton::gpu;
namespace ttng = triton::nvidia_gpu;

// Lazily initialized MLIR dump stream. The stream is initialized only on first
// call using the MLIR_DUMP_PATH environment variable.
std::unique_ptr<llvm::raw_fd_ostream> mlir_dump_stream;

// Return a reference to the MLIR dump stream.
llvm::raw_fd_ostream &mlir_dumps() {
  std::error_code EC;
  static llvm::raw_fd_ostream S(::triton::tools::getStrEnv("MLIR_DUMP_PATH"),
                                EC, llvm::sys::fs::CD_CreateAlways);
  assert(!EC);
  return S;
  if (!mlir_dump_stream) {
Contributor

@peterbell10 peterbell10 Apr 24, 2025

> This change ensures the dump path is re-evaluated on each call,
> enabling flexible, per-kernel IR dumping.

This obviously isn't true, since you only check the env variable on the first call. The only meaningful change here is supporting `MLIR_DUMP_PATH` as a directory, but in that case why not use `TRITON_DUMP_DIR`, which is probably exactly what you're looking for?

Author

Thank you for the feedback.

> This obviously isn't true since you only check the env variable on the first call.

Correct. What I meant was that in `mlir_dumps_or_dbgs()`, the environment variable is checked on every call, but the dump stream is initialized only once. I will correct the commit message and add a clarifying comment to this function in the updated commit.

> The only meaningful change here is supporting MLIR_DUMP_PATH as a directory,

Correct.

> but in that case why not use TRITON_DUMP_DIR which is probably exactly what you're looking for.

Not quite. Based on the "Tips for hacking" section of the README and the contents of the dumped files,
my understanding is that `MLIR_DUMP_PATH` and `TRITON_DUMP_DIR` serve different purposes.
`MLIR_DUMP_PATH` is controlled by `MLIR_ENABLE_DUMP` and is used to store MLIR-native IR dumps.
`TRITON_DUMP_DIR` is controlled by `TRITON_KERNEL_DUMP` and is used to store Triton IR, LLVM IR, and backend binaries, such as `.cubin` and `.ptx` files.

The purpose of this PR is to prevent crashes when MLIR_DUMP_PATH is set to a directory,
and to make its behavior consistent with what the README describes.

I'm currently learning the Triton code base, so please let me know if I've missed anything. I'd really appreciate it.

    std::error_code EC;
    std::string base = ::triton::tools::getStrEnv("MLIR_DUMP_PATH");
    std::string filename = base;

    // Generate a unique filename within MLIR_DUMP_PATH if it is set to a
    // directory.
    if (llvm::sys::fs::is_directory(base)) {
      std::time_t now = std::time(nullptr);
      int pid = getpid();
      filename = base + "/triton_dump_" + std::to_string(pid) + "_" +
                 std::to_string(now) + ".mlir";
    }

    // Create the output stream to the resolved path.
    mlir_dump_stream = std::make_unique<llvm::raw_fd_ostream>(
        filename, EC, llvm::sys::fs::CD_CreateAlways);

    // Abort the program if the dump file cannot be opened.
    if (EC) {
      llvm::errs() << "Failed to open IR dump file: " << filename << "\n";
      llvm::errs() << "Error: " << EC.message() << "\n";
      std::abort();
    }
  }

  return *mlir_dump_stream;
}

// Return a reference to either the dump stream or the debug stream, depending
// on the MLIR_DUMP_PATH environment variable. The environment variable is
// checked on every call, but the dump stream is initialized only once.
llvm::raw_ostream &mlir_dumps_or_dbgs() {
  if (!::triton::tools::getStrEnv("MLIR_DUMP_PATH").empty()) {
    return mlir_dumps();
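For readers following the review thread, the behavior of the patched `mlir_dumps()` and `mlir_dumps_or_dbgs()` can be modeled with a minimal Python sketch. This is an illustration only, not Triton code: `resolve_dump_path` and `LazyDumpStream` are hypothetical names, and the sketch assumes that an unset `MLIR_DUMP_PATH` behaves like an empty string, as the `.empty()` check in the diff suggests. It shows the two points under discussion: a directory value is expanded to `triton_dump_<pid>_<timestamp>.mlir`, and the environment is consulted on every call while the file is opened only once.

```python
import os
import sys
import time


def resolve_dump_path(base, pid=None, now=None):
    """Return the file to write to.

    If `base` is an existing directory, generate a file name inside it from
    the process ID and a timestamp in seconds; otherwise use `base` as is.
    """
    if os.path.isdir(base):
        pid = os.getpid() if pid is None else pid
        now = int(time.time()) if now is None else now
        return os.path.join(base, f"triton_dump_{pid}_{now}.mlir")
    return base


class LazyDumpStream:
    """Hypothetical stand-in for the global `mlir_dump_stream` pointer."""

    def __init__(self):
        self._stream = None

    def get(self, environ=None):
        """Return the dump stream, or stderr when MLIR_DUMP_PATH is empty."""
        environ = os.environ if environ is None else environ
        base = environ.get("MLIR_DUMP_PATH", "")  # read on every call
        if not base:
            return sys.stderr
        if self._stream is None:  # opened on the first call only
            self._stream = open(resolve_dump_path(base), "w")
        return self._stream

    def close(self):
        if self._stream is not None:
            self._stream.close()
            self._stream = None
```

In this model, a second call to `get()` with a different `MLIR_DUMP_PATH` still returns the stream opened by the first call, which is the point the reviewer raises: the path is not re-evaluated per kernel once the stream exists.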
58 changes: 58 additions & 0 deletions python/test/unit/test_mlir_dump_path.py
@@ -0,0 +1,58 @@
import subprocess
import tempfile
import pytest
import textwrap
from pathlib import Path
import os


@pytest.mark.forked
def test_basic_mlir_dump(monkeypatch):
    dump_dir = Path(tempfile.mkdtemp())
    monkeypatch.setenv("MLIR_ENABLE_DUMP", "1")
    monkeypatch.setenv("MLIR_DUMP_PATH", str(dump_dir))

    kernel_code = textwrap.dedent("""
        import triton
        import triton.language as tl
        import torch

        @triton.jit
        def dummy_kernel(x_ptr, y_ptr, n_elements: tl.constexpr):
            pid = tl.program_id(0)
            offs = pid * 128 + tl.arange(0, 128)
            mask = offs < n_elements
            x = tl.load(x_ptr + offs, mask=mask)
            tl.store(y_ptr + offs, x, mask=mask)

        if __name__ == "__main__":
            n = 1024
            x = torch.arange(n, dtype=torch.float32, device="cuda")
            y = torch.empty_like(x)
            dummy_kernel[(n // 128,)](x, y, n_elements=n)
    """)

    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(kernel_code)
        script_path = f.name

    repo_root = Path(__file__).resolve().parents[2]
    env = os.environ.copy()
    env["PYTHONPATH"] = str(repo_root) + ":" + env.get("PYTHONPATH", "")

    result = subprocess.run(
        ["python3", script_path],
        env=env,
        capture_output=True,
        text=True,
    )

    # Diagnostic printing if failure
    if result.returncode != 0:
        print("STDOUT:\n", result.stdout)
        print("STDERR:\n", result.stderr)

    assert result.returncode == 0, "Triton kernel script failed"
    assert any(f.suffix == ".mlir" for f in dump_dir.iterdir()), "No MLIR dump generated"

    os.remove(script_path)
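
The test above hard-codes `":"` as the `PYTHONPATH` separator and leaves a trailing separator when `PYTHONPATH` is unset. A possible refinement, offered only as a sketch and not part of the PR, factors the environment setup and the dump check into two small helpers. Both `build_subprocess_env` and `find_mlir_dumps` are hypothetical names introduced here for illustration.

```python
import os
from pathlib import Path


def build_subprocess_env(repo_root, base_env=None):
    """Copy the environment and prepend `repo_root` to PYTHONPATH.

    Uses os.pathsep instead of a literal ":" and skips the existing value
    when it is empty, so no trailing separator is produced.
    """
    env = dict(os.environ if base_env is None else base_env)
    parts = [str(repo_root)]
    existing = env.get("PYTHONPATH", "")
    if existing:
        parts.append(existing)
    env["PYTHONPATH"] = os.pathsep.join(parts)
    return env


def find_mlir_dumps(dump_dir):
    """Return the `.mlir` files directly inside `dump_dir`, sorted by path."""
    return sorted(p for p in Path(dump_dir).iterdir() if p.suffix == ".mlir")
```

With these helpers, the test body could pass `env=build_subprocess_env(repo_root)` to `subprocess.run` and assert that `find_mlir_dumps(dump_dir)` is non-empty. Wrapping the subprocess call in `try`/`finally` would also ensure the temporary script is removed when an assertion fails.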