When lowering KernelBench level1/14_Matmul_for_upper_triangluar_matrices.py and level1/15_Matmul_for_lower_triangluar_matrices.py, I get very different lowering results.
Basically, 15_Matmul_for_lower_triangluar_matrices.py is:
return torch.tril(torch.matmul(A, B))
Which lowers to a pair of ops: a linalg.matmul and a follow-up linalg.generic that simply takes the indices, compares i and j as expected and selects the correct value. Using upstream MLIR this works as expected: the matmul is tiled and fused, and the code runs really fast.
On the other hand, 14_Matmul_for_upper_triangluar_matrices.py is:
return torch.triu(torch.matmul(A, B))
Which looks like it could produce an identical lowering, just changing the source of the select. But alas, what is produced is a bloated and convoluted sequence. The output is correct, but MLIR has a hard time making heads or tails of it.
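To make the expected symmetry concrete, here is a minimal plain-Python sketch. It is not the actual lowering, and `masked_select` and its `keep` predicate are hypothetical names used only for illustration. The point is that tril and triu differ only in the direction of the index comparison:

```python
def masked_select(m, keep):
    """Keep m[i][j] where keep(i, j) is true, write zero elsewhere."""
    return [[v if keep(i, j) else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def tril(m):
    # Lower triangle: keep elements on or below the diagonal.
    return masked_select(m, lambda i, j: j <= i)


def triu(m):
    # Upper triangle: same select, comparison flipped.
    return masked_select(m, lambda i, j: j >= i)


product = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # stand-in for matmul(A, B)
print(tril(product))  # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
print(triu(product))  # [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
```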
The sequence is:
- A generic that constructs an "index tensor" {0, 1, 2, ..., N-1}
- Two expand_shape ops on that tensor to create a column tensor and a row tensor for the next step
- Another generic that subtracts the row/column tensors to create a 2D tensor <NxNxTy>
- Another generic that compares each 2D element with a constant zero, creating another tensor <NxNxi1>
- Another generic that zips the two tensors above, selecting from the latter to zero the former
As is clear, the semantics are the same, just expressed in a roundabout way.
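The equivalence can be checked with a small plain-Python model of the sequence. This is only a sketch under assumptions: the subtraction order (column index minus row index) and the `>= 0` comparison are my guesses at what the IR does, and the function names are hypothetical. It models the steps, not the actual linalg ops:

```python
def triu_via_index_tensors(m):
    """Model of the materialized sequence: every step builds a full intermediate."""
    n = len(m)
    idx = list(range(n))                      # generic: index tensor {0, ..., N-1}
    col = [[v] for v in idx]                  # expand_shape: N -> Nx1
    row = [idx]                               # expand_shape: N -> 1xN
    diff = [[row[0][j] - col[i][0] for j in range(n)]
            for i in range(n)]                # generic: NxN difference (assumed j - i)
    mask = [[d >= 0 for d in r] for r in diff]  # generic: NxN boolean mask
    return [[m[i][j] if mask[i][j] else 0 for j in range(n)]
            for i in range(n)]                # generic: select value or zero


def triu_direct(m):
    """What a single generic with an index comparison would compute."""
    n = len(m)
    return [[m[i][j] if j >= i else 0 for j in range(n)] for i in range(n)]


product = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # stand-in for matmul(A, B)
assert triu_via_index_tensors(product) == triu_direct(product)
```

Both paths give the same result, but the first one carries two NxN intermediates (`diff` and `mask`) that the direct form never needs.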
Now, the last three generics do fuse under the linalg-fuse-elementwise-ops pass, but since there are two expand_shapes in between, the first generic does not. The expand_shapes don't get converted to a reasonable affine map, so fusion bails, leaving the code to materialize a large matrix for no reason.
If there were no other way, I could propose some more fiddling with the MLIR pipeline, maybe adding some more transforms. But since tril does the right thing, I'm wondering why triu does not.
Reproducer
- Clone Lighthouse
- Init uv as in README
- Run:
$ uv run examples/KernelBench/test-kernel-bench.py --kernel level1/14_ --print-original-module
- Run:
$ uv run examples/KernelBench/test-kernel-bench.py --kernel level1/15_ --print-original-module
- Compare the IRs
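For the comparison step, a small helper along these lines may be handy. It is a sketch that assumes `--print-original-module` writes the module to stdout and that it is run from the Lighthouse checkout. `dump_ir` and `ir_diff` are hypothetical names, not part of Lighthouse:

```python
import difflib
import subprocess


def dump_ir(kernel_prefix):
    """Run the reproducer command for one kernel and return its stdout.

    Assumes the printed module goes to stdout.
    """
    cmd = ["uv", "run", "examples/KernelBench/test-kernel-bench.py",
           "--kernel", kernel_prefix, "--print-original-module"]
    return subprocess.run(cmd, capture_output=True, text=True, check=True).stdout


def ir_diff(triu_ir, tril_ir):
    """Return a unified diff of two IR dumps as a list of lines."""
    return list(difflib.unified_diff(triu_ir.splitlines(), tril_ir.splitlines(),
                                     "triu", "tril", lineterm=""))
```

Usage would be `print("\n".join(ir_diff(dump_ir("level1/14_"), dump_ir("level1/15_"))))`.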