
Inconsistent lowering of Torch tril and triu #4604

Description

@rengolin

When lowering KernelBench level1/14_Matmul_for_upper_triangluar_matrices.py and level1/15_Matmul_for_lower_triangluar_matrices.py, I get very different lowering results.

Basically, 15_Matmul_for_lower_triangluar_matrices.py is:

  return torch.tril(torch.matmul(A, B))

This lowers to a pair of ops: a linalg.matmul and a follow-up linalg.generic that simply takes the indices, compares i and j as expected, and selects the correct value. With upstream MLIR this works as expected: the matmul is tiled and fused, and the code runs really fast.

On the other hand, 14_Matmul_for_upper_triangluar_matrices.py is:

  return torch.triu(torch.matmul(A, B))

Which looks like it could produce an identical lowering, just changing the source of the select. But alas, what is produced is a bloated and convoluted sequence. The output is correct, but MLIR has a hard time making heads or tails of it.

The sequence is:

  • A generic that constructs an "index tensor" {0, 1, 2, ..., N-1}
  • Two expand_shape ops on that tensor to create a column tensor and a row tensor for the next step
  • Another generic that subtracts the row/column tensors to create a 2D tensor <NxNxTy>
  • Another generic that compares each 2D element with a constant zero, creating another tensor <NxNxi1>
  • Another generic that zips the two tensors above, selecting from the latter to zero the former

As is clear, the semantics are the same as in the tril case, just expressed in a roundabout way.
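
To make the equivalence concrete, here is a minimal dependency-free Python sketch of the two shapes of lowering, applied to triu. It is not the generated IR: the function names are hypothetical, nested lists stand in for tensors, and the exact subtraction order and comparison predicate (column minus row, kept when >= 0) are my assumption about what the lowering emits.

```python
def triu_direct(m):
    """One pass: compare the indices i and j directly and select the value or zero."""
    n = len(m)
    return [[m[i][j] if j >= i else 0 for j in range(n)] for i in range(n)]


def triu_materialized(m):
    """Same result, via the intermediate tensors listed above (assumed predicate)."""
    n = len(m)
    idx = list(range(n))            # generic: index tensor {0, ..., N-1}
    col = [[i] for i in idx]        # expand_shape: N x 1 column tensor
    row = [idx[:]]                  # expand_shape: 1 x N row tensor
    # generic: subtract the column tensor from the row tensor -> N x N tensor
    diff = [[row[0][j] - col[i][0] for j in range(n)] for i in range(n)]
    # generic: compare each element with zero -> N x N boolean mask
    mask = [[d >= 0 for d in r] for r in diff]
    # generic: zip values with the mask, zeroing where the mask is false
    return [[v if keep else 0 for v, keep in zip(vr, kr)]
            for vr, kr in zip(m, mask)]


if __name__ == "__main__":
    m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    print(triu_direct(m))
    print(triu_materialized(m))
```

Both functions return the same matrix; the second one just builds two full N x N intermediates (diff and mask) that the first never needs.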

Now, the last three generics do fuse under the linalg-fuse-elementwise-ops pass, but since there are two expand_shape ops in between, the first generic does not. The expand_shape ops don't get converted to a reasonable affine map, so fusion bails, leaving the code to materialize a large matrix for no reason.

If there were no other way, I could propose some more fiddling with the MLIR pipeline, maybe adding some more transforms. But since tril does the right thing, I'm wondering why triu does not.

Reproducer

  • Clone Lighthouse
  • Init uv as in README
  • Run: $ uv run examples/KernelBench/test-kernel-bench.py --kernel level1/14_ --print-original-module
  • Run: $ uv run examples/KernelBench/test-kernel-bench.py --kernel level1/15_ --print-original-module
  • Compare the IRs
