[LAYOUTS] Generic stmatrix lowering #6609

Open
wants to merge 11 commits into
base: main

lezcano
Contributor

@lezcano lezcano commented Apr 25, 2025

We use divideLeft to lower a generic local_store using stmatrix
whenever possible.

We implement ColumnAction as a helper that allows us to permute
the bases and values and remove broadcasting in a generic way.

The current codegen for stmatrix has some limitations, so we just
return early in those cases. We'll fix those in a future PR.

@lezcano lezcano requested review from Jokeren and ptillet as code owners April 25, 2025 14:01
@lezcano lezcano changed the title [LAYOUTS] Lower stmatrix generically [LAYOUTS] Generic stmatrix lowering Apr 25, 2025
auto regBase = applyLinearLayout(loc, rewriter, quot,
{{kReg, b.lshr(laneId, b.i32_val(3))},
{kLane, b.and_(laneId, b.i32_val(0x7))},
{kWarp, warpId}})[0]
Contributor

What's going on here, seems like some abuse is going on?

Contributor Author

Abuse? This just implements the address map that ldmatrix asks for:
[image]
In particular, the lower three bits of the lane should map to the columns, while the top two bits should map to the 4 different matrices (given by the first 2 bases of the reps of quot). I'll write a comment.
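
As a rough sketch of the bit split described above (not code from the PR), the lane id of a 32-lane warp can be decomposed as follows. The names `LaneSplit` and `splitLane` are hypothetical and only mirror the `lshr(laneId, 3)` and `and_(laneId, 0x7)` expressions in the diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, for illustration only.
struct LaneSplit {
  uint32_t matrix; // top 2 bits of the lane: which of the 4 submatrices
  uint32_t col;    // low 3 bits of the lane: the "column" in the comment above
};

inline LaneSplit splitLane(uint32_t laneId) {
  assert(laneId < 32 && "a warp has 32 lanes");
  // laneId >> 3 corresponds to b.lshr(laneId, b.i32_val(3)) in the diff,
  // laneId & 0x7 corresponds to b.and_(laneId, b.i32_val(0x7)).
  return {laneId >> 3, laneId & 0x7};
}
```

For instance, lane 13 (binary 01101) would land in submatrix 1 with low bits 5.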

Contributor

@peterbell10 peterbell10 Apr 25, 2025

But you're setting the register index based on the lane id so clearly you're abusing the labels kReg and kLane to mean something else. That's pretty confusing.

Contributor

I'm a bit concerned that this new feature that is supposed to simplify all our lowerings is resulting in (for me) unreadable code.

Contributor

Okay, I understand now. When using .x4 you pass the pointer corresponding to the row for register T0:r[i] in T(0 + 8*i):p so the lane really is giving the register index. Makes sense. I think I would restate your comment though as:

  // Here we implement the stmatrix.x4 addressing. In particular, the row pointers
  // for each submatrix r in thread t are communicated by the stmatrix call in
  // laneId = t + (8 * r), so r = laneId // 8 and t = laneId % 8.

Contributor Author

I rewrote it in a third way, tell me WDYT.

@lezcano
Contributor Author

lezcano commented Apr 25, 2025

fwiw, before merging, I'll implement something so that we only lower via this path if the lowering would have no bank conflicts, just to make sure we don't over-use this function. I also still need to run benchmarks.

Contributor

@Jokeren Jokeren left a comment

Looks good to me. The only comment I have is that permuteInDimToFront can be very useful in other places where we need register permutation, so it's probably better to put it in LayoutUtils.h

@Jokeren
Contributor

Jokeren commented Apr 26, 2025

Also I think stmatrix use in convertlayout op lowering hasn't been replaced yet. Could be done in another PR though

}

std::optional<ColumnAction> actionDivideLeft(const LinearLayout &A,
const LinearLayout &B) {
Contributor

IIUC this is more like findRegPermutationThatDividesLeft?

Contributor Author

sure, I was just looking for a shorter name as this function is going to be used quite a bit. I was also thinking of packing this + the division and action on the regs into a helper function. Might do that tomorrow.

Contributor Author

chose a shorter but representative name

// first submatrix, threads 8-15 for the second submatrix, etc. In general we
// map:
// - The lowest 3 bits of the thread id to the columns of each submatrix
// - The top 2 bits to the submatrix number (which is indexed by the next 2
Contributor

I don't find this very helpful because it relates thread id to columns of the matrix, but not to register and lane id which is what you're giving to the linear layout. So this comment is correct, but doesn't explain the code at all.

Contributor Author

@lezcano lezcano Apr 28, 2025

The mapping of the 2 bits into the registers is explained in the parentheses on the line you commented on.
The only thing that's missing is noting that each column i starts with thread t[4*i], and since the quotient has already removed the 4 threads, on this layout the map is i -> t[i].
The second part is a bit redundant once you have intuition for what divideLeft does, though.

Would you want me to add that?

Contributor

@peterbell10 peterbell10 Apr 28, 2025

Ah so kLane refers to the lane id with two bases removed such that it's not really a lane id any more. This is why you refer to them by different names now in the comment, but not in the code.

Perhaps it would be less confusing to always act on the "reps" layout? idk

Contributor Author

In this case that'd be equivalent to the current state, generating the same number of ops.

Now, the better codegen would be to simply create a new layout composed of:

  • 3 zero bases and the first vec non-zero bases of `quot[kReg]` (2 in this case)
  • The first 3 non-zero bases of `quot[kLane]`
  • All the warp bases

We can call these dimensions kVec, kCol, kWarp respectively.

This would make sure that you can just pass

applyLinearLayout(loc, rewriter, quot,
                                   {{kVec, laneId},
                                    {kCol, laneId},
                                    {kWarp, warpId}})[0]

This linear layout has only 5 ones in the columns of kReg and kLane, so it would generate a pretty low op count (note that now there are more ones in the kReg dim, as we didn't trim them; LLVM may be able to optimise these, but who knows).
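
To make the op-count argument concrete, here is a minimal sketch assuming the usual reading of a linear layout: the output is the XOR of the basis vectors selected by the set bits of the input, so zero bases contribute nothing. The helpers `applyBases` and `countNonZeroBases` and all basis values are made up for illustration; they are not Triton APIs or the actual bases of `quot`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// bases[i] is the output contribution of input bit i (illustrative model only).
inline uint32_t applyBases(const std::vector<uint32_t> &bases, uint32_t in) {
  assert(bases.size() <= 32 && "input is a 32-bit index");
  uint32_t out = 0;
  for (std::size_t i = 0; i < bases.size(); ++i)
    if ((in >> i) & 1u)
      out ^= bases[i]; // a zero basis leaves the output unchanged
  return out;
}

// Number of bases that would actually need an operation.
inline int countNonZeroBases(const std::vector<uint32_t> &bases) {
  int n = 0;
  for (uint32_t b : bases)
    n += (b != 0);
  return n;
}
```

With made-up bases `kVec = {0, 0, 0, 64, 128}` and `kCol = {8, 16, 32}`, the same lane id can be fed to both dimensions: kVec ignores the low 3 bits through its zero bases, kCol only has bases for the low 3 bits, and together they contain 5 non-zero bases.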

Contributor Author

I'm leaving things as-is for now, though. The main thing to note is that, after dividing by a layout, the labels on the resulting layout do not refer to the hardware coords (nor do the offsets on the output refer to the actual offsets); they refer to equivalence classes.

In CS terms, this is like having a float pointer, where an increment of 1 with operator[] moves you 4 bytes. A division is then like casting it to a float4 pointer (grouping a tile of 4 together). The result of the division tells you how to move around in the float4 world, where you have collapsed 4 floats into one element. In particular, in this world, moving by 1 moves you 16 bytes.

This is what we are doing but with a more complex tile, where moving 1 over a thread moves us 8 offsets (i.e. one full tile, as per the line regBase = b.shl(regBase, b.i32_val(tile.getTotalOutDimSizeLog2()));)
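
The analogy can be sketched in plain C++ as follows. `Float4` is a stand-in struct (not CUDA's `float4`), and `tileIndexToElemOffset` is a hypothetical helper that mirrors the quoted `shl` by the log2 of the tile size; the sketch assumes a platform with 4-byte floats.

```cpp
#include <cassert>
#include <cstddef>

// Stand-in for a 4-float tile; assumes sizeof(float) == 4.
struct Float4 {
  float x, y, z, w;
};
static_assert(sizeof(float) == 4, "sketch assumes 4-byte floats");
static_assert(sizeof(Float4) == 16, "a tile of 4 floats is 16 bytes");

// Bytes moved by `steps` increments of a float pointer.
inline std::size_t bytesPerFloatSteps(std::size_t steps) {
  return steps * sizeof(float);
}

// Bytes moved by `steps` increments after "dividing" into Float4 tiles.
inline std::size_t bytesPerTileSteps(std::size_t steps) {
  return steps * sizeof(Float4);
}

// Convert an index in the divided (tile) world back to an element offset,
// in the spirit of regBase = shl(regBase, log2(tile size)).
inline std::size_t tileIndexToElemOffset(std::size_t tileIdx,
                                         unsigned tileSizeLog2) {
  assert(tileSizeLog2 < sizeof(std::size_t) * 8);
  return tileIdx << tileSizeLog2;
}
```

Under these assumptions, one step in the tile world with a tile of 8 elements (`tileSizeLog2 = 3`) corresponds to 8 element offsets.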

Contributor Author

nvm, I implemented it. Now the codegen should be better (on our end at least) and hopefully the explanation for why we create the layout should be clearer. WDYT

lezcano added 7 commits April 29, 2025 10:57
We use `divideLeft` to lower a generic `local_store` using `stmatrix`
whenever possible.

The current codegen for `stmatrix` has some limitations, so we just
return early in those cases. We'll fix those in a future PR.
Generalise lowering by accepting permuted register layouts
Simplify the lowering by using ColumnAction.
if (!sharedLayout)

// Inter block stmatrix is not supported
if (cvt.hasInDim(kBlock))
Contributor Author

cc @peterbell10
There was a test in test_tensor_descriptor.py that was not passing with 2 CTAs. Now, with this line, it passes. This line says that we just bail out if there's any inter-CTA business going on. It seems to work just fine as long as you are addressing things within your own CTA, even without any map or anything.

@lezcano
Contributor Author

lezcano commented Apr 29, 2025

This is ready for review:

  • It's missing benchmarking; I'll do that tomorrow.
  • I'll do the convertlayout port in a different PR; it should be easy.
  • I'll do the generalisation of stmatrix to non-bf16/f16 types, other vectorizations, and transpose in a different PR.

Contributor

@peterbell10 peterbell10 left a comment

My remaining comments are non-blocking, we can chat offline.

3 participants