[LAYOUTS] Generic stmatrix lowering #6609
auto regBase = applyLinearLayout(loc, rewriter, quot,
                                 {{kReg, b.lshr(laneId, b.i32_val(3))},
                                  {kLane, b.and_(laneId, b.i32_val(0x7))},
                                  {kWarp, warpId}})[0]
What's going on here, seems like some abuse is going on?
But you're setting the register index based on the lane id, so clearly you're abusing the labels `kReg` and `kLane` to mean something else. That's pretty confusing.
I'm a bit concerned that this new feature that is supposed to simplify all our lowerings is resulting in (for me) unreadable code.
Okay, I understand now. When using `.x4` you pass the pointer corresponding to the row for register `T0:r[i]` in `T(0 + 8*i):p`, so the lane really is giving the register index. Makes sense. I think I would restate your comment though as:
// Here we implement the stmatrix.x4 addressing. In particular, the row pointers
// for each submatrix r in thread t are communicated by the stmatrix call in
// laneId = (t // 8) + (8 * r), so r = laneId // 8 and t = laneId % 8.
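To make the lane split concrete, here is a minimal standalone sketch of the decomposition that the `b.lshr(laneId, 3)` / `b.and_(laneId, 0x7)` pair computes. The `LaneCoords` struct and the `decodeLane` / `encodeLane` names are hypothetical and exist only for this illustration; they are not part of the PR.

```cpp
#include <cstdint>

// Hypothetical helper type (not in the PR): which 8x8 submatrix a lane
// supplies a pointer for, and which row of that submatrix.
struct LaneCoords {
  std::uint32_t submatrix;  // r = laneId / 8, the value passed as kReg
  std::uint32_t row;        // t = laneId % 8, the value passed as kLane
};

// Split a lane id in [0, 32) into (submatrix, row).
constexpr LaneCoords decodeLane(std::uint32_t laneId) {
  return LaneCoords{laneId >> 3, laneId & 0x7u};
}

// Inverse map: the lane that carries the pointer for a given (submatrix, row).
constexpr std::uint32_t encodeLane(LaneCoords c) {
  return c.row + 8u * c.submatrix;
}

static_assert(decodeLane(13).submatrix == 1 && decodeLane(13).row == 5,
              "lane 13 carries row 5 of submatrix 1");
static_assert(encodeLane(decodeLane(27)) == 27, "decode/encode round-trips");
```

Under this reading, lanes 0-7 carry the row pointers of the first submatrix, lanes 8-15 those of the second, and so on.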
I rewrote it in a third way, tell me WDYT.
fwiw, before merging, I'll implement something so that we only lower via this path if the lowering would have no bank conflicts, just to make sure we don't over-use this function. I also still need to run benchmarks.
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
Looks good to me. The only comment I have is that `permuteInDimToFront` can be very useful in other places where we need register permutation, so it is probably better to put it in LayoutUtils.h.
Also, I think the stmatrix use in the convertlayout op lowering hasn't been replaced yet. That could be done in another PR though.
lib/Tools/LayoutUtils.cpp
}

std::optional<ColumnAction> actionDivideLeft(const LinearLayout &A,
                                             const LinearLayout &B) {
IIUC this is more like `findRegPermutationThatDividesLeft`?
Sure, I was just looking for a shorter name, as this function is going to be used quite a bit. I was also thinking of packing this, the division, and the action on the regs into a helper function. Might do that tomorrow.
Chose a shorter but representative name.
// first submatrix, threads 8-15 for the second submatrix, etc. In general we
// map:
// - The lowest 3 bits of the thread id to the columns of each submatrix
// - The top 2 bits to the submatrix number (which is indexed by the next 2
I don't find this very helpful because it relates thread id to columns of the matrix, but not to register and lane id which is what you're giving to the linear layout. So this comment is correct, but doesn't explain the code at all.
The mapping of the 2 bits into the registers is explained in the parens in the line you commented on.
The only thing that's missing is noting that each column `i` starts with thread `t[4*i]`, and since the quotient has already removed the 4 threads, on this layout the map is `i -> t[i]`.
The second part is a bit redundant once you have intuition for what `divideLeft` does though.
Would you want me to add that?
Ah, so `kLane` refers to the lane id with two bases removed, such that it's not really a lane id any more. This is why you refer to them by different names now in the comment, but not in the code.
Perhaps it would be less confusing to always act on the "reps" layout? idk
In this case that'd be equivalent to the current state, generating the same number of ops.
Now, the better codegen would be to simply create a new layout composed of:
- 3 zero bases and the first `vec` non-zero bases of `quot[kReg]` (2 in this case)
- The first 3 non-zero bases of `quot[kLane]`
- All the warp bases
We can call these dimensions `kVec`, `kCol`, `kWarp` respectively.
This would make sure that you can just pass
applyLinearLayout(loc, rewriter, quot,
                  {{kVec, laneId},
                   {kCol, laneId},
                   {kWarp, warpId}})[0]
This linear layout has only 5 ones in the columns of `kReg` and `kLane`, so it would generate a pretty low op count (note that now there are more ones in the `kReg` dim as we didn't trim them. LLVM may be able to optimise these, but who knows).
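As a rough illustration of why passing `laneId` to both dimensions works, here is a toy model of applying a linear layout, assuming the usual convention that the output is the XOR of the bases selected by the set bits of the input. The `applyBases` and `toyOffset` helpers and all basis values below are made up for the example; they are not the bases of the real `quot` layout and not Triton's API.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model: bases[i] is the output contributed by input bit i.
// Input bits beyond bases.size() are ignored.
inline std::uint32_t applyBases(const std::vector<std::uint32_t>& bases,
                                std::uint32_t index) {
  std::uint32_t out = 0;
  for (std::size_t bit = 0; bit < bases.size(); ++bit)
    if ((index >> bit) & 1u) out ^= bases[bit];
  return out;
}

// Hypothetical bases, chosen only for illustration:
// kVec: 3 zero bases (the low lane bits contribute nothing), then 2 non-zero.
// kCol: 3 non-zero bases, read from the low 3 lane bits.
inline std::uint32_t toyOffset(std::uint32_t laneId) {
  const std::vector<std::uint32_t> vecBases = {0, 0, 0, 64, 128};
  const std::vector<std::uint32_t> colBases = {8, 16, 32};
  return applyBases(vecBases, laneId) ^ applyBases(colBases, laneId);
}
```

In this toy, only 5 bases are non-zero, so only 5 lane bits ever contribute an XOR term; the zero bases let the full `laneId` be passed without masking or shifting it first.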
I'm leaving things as-is for now though. The main thing to note is that, after dividing by a layout, the labels on the resulting layout do not refer to the hardware coords (nor do the offsets on the output refer to the actual offsets); they refer to equivalence classes.
In CS terms, this would be like when you have a `float` ptr and doing increments of `1` with `operator[]` moves you `4` bytes. Then a division would be like casting it to `float4` (grouping a tile of 4 together). Then, the result of the division tells you how to move around in the `float4` world, where you have collapsed 4 floats into one element. In particular, in this world, moving `1` moves you `16` bytes.
This is what we are doing but with a more complex tile, where moving `1` over a thread moves us `8` offsets (i.e. one full tile, as per the line `regBase = b.shl(regBase, b.i32_val(tile.getTotalOutDimSizeLog2()));`).
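The pointer analogy can be checked with a small standalone sketch. The `Float4` struct, `byteStride`, and `tileIndexToOffset` are hypothetical names used only here; the last one mimics the `b.shl(regBase, ...)` line under the assumption that the tile size is a power of two.

```cpp
#include <cstddef>
#include <cstdint>

// Stand-in for a float4 vector type: a tile of 4 floats treated as one element.
struct Float4 { float x, y, z, w; };
static_assert(sizeof(float) == 4, "example assumes 4-byte floats");
static_assert(sizeof(Float4) == 16, "example assumes a packed 16-byte tile");

// Bytes moved when stepping n elements from base (base + n must stay in bounds).
template <typename T>
std::ptrdiff_t byteStride(const T* base, std::ptrdiff_t n) {
  return reinterpret_cast<const char*>(base + n) -
         reinterpret_cast<const char*>(base);
}

// Stepping by 1 in the float world moves 4 bytes.
inline std::ptrdiff_t floatStepBytes() {
  float data[8] = {};
  return byteStride(data, 1);
}

// Stepping by 1 in the "divided" float4 world moves a whole tile: 16 bytes.
inline std::ptrdiff_t float4StepBytes() {
  Float4 tiles[2] = {};
  return byteStride(tiles, 1);
}

// Going back from a tile index to an element offset is a left shift by
// log2(tile size), which is the role of the b.shl on regBase.
constexpr std::uint32_t tileIndexToOffset(std::uint32_t tileIndex,
                                          std::uint32_t log2TileSize) {
  return tileIndex << log2TileSize;
}
```

With a tile of 8 offsets (`log2TileSize == 3`), tile index 1 maps to offset 8, matching the "moving `1` moves us `8` offsets" description.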
nvm, I implemented it. Now the codegen should be better (on our end at least) and hopefully the explanation for why we create the layout should be clearer. WDYT
We use `divideLeft` to lower a generic `local_store` using `stmatrix` whenever possible. The current codegen for `stmatrix` has some limitations, so we just return early in those cases. We'll fix those in a future PR.
Generalise lowering by accepting permuted register layouts
Simplify the lowering by using ColumnAction.
if (!sharedLayout)

// Inter block stmatrix is not supported
if (cvt.hasInDim(kBlock))
cc @peterbell10
There was a test in `test_tensor_descriptor.py` that was not passing with 2 CTAs. Now, with this line, it passes. This line says that we just bail out if there's any inter-CTA business going on. It seems to work just fine as long as you are addressing things within your own CTA, even without any map or anything.
This is ready for review:
My remaining comments are non-blocking, we can chat offline.
We use `divideLeft` to lower a generic `local_store` using `stmatrix` whenever possible.
We implement `ColumnAction` as a helper that allows us to permute the bases and values and remove broadcasting in a generic way.
The current codegen for `stmatrix` has some limitations, so we just return early in those cases. We'll fix those in a future PR.
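For intuition about what "permute the bases and values and remove broadcasting" could mean, here is a toy sketch. It is not the `ColumnAction` class from the PR: `permuteBases`, `permuteValues`, and the convention that `action[i]` names the old basis that becomes new basis `i` are assumptions made for this example.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in: keep only the bases listed in `action`, in that order.
// Leaving a zero (broadcast) basis out of `action` drops it.
std::vector<std::uint32_t> permuteBases(const std::vector<std::uint32_t>& bases,
                                        const std::vector<std::size_t>& action) {
  std::vector<std::uint32_t> out;
  out.reserve(action.size());
  for (std::size_t oldBit : action) out.push_back(bases[oldBit]);
  return out;
}

// Reorder per-register values to match: bit i of the new register index
// corresponds to bit action[i] of the old register index.
std::vector<int> permuteValues(const std::vector<int>& values,
                               const std::vector<std::size_t>& action) {
  std::vector<int> out(std::size_t{1} << action.size());
  for (std::size_t j = 0; j < out.size(); ++j) {
    std::size_t oldIdx = 0;
    for (std::size_t i = 0; i < action.size(); ++i)
      if ((j >> i) & 1u) oldIdx |= std::size_t{1} << action[i];
    out[j] = values[oldIdx];
  }
  return out;
}
```

For example, with register bases `{4, 0, 1}` (the middle basis is a broadcast) and `action = {2, 0}`, the new bases are `{1, 4}`, and 8 register values shrink to the 4 distinct ones in the matching order.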