[Operator] Add isin op #188
zhzhcookie
commented
Aug 28, 2024
- Add isin op.
- Add unit tests and performance tests of isin op.
- Optimize the implementation of the unique op, which is used by the isin op.
LGTM. Just wait for the libentry bug to be fixed, then pull master :)
3931767 to 1efa7ab
src/flag_gems/ops/isin.py
Outdated
    BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 256, N, 8)
else:
    BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)
num_ctas = min(65536, triton.cdiv(M, BLOCK_M))
If this number is to be used as grid.y, then the current limit is 65535.
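For reference, a minimal sketch of the clamp this comment points at, written in plain Python rather than Triton. The helper names `cdiv` and `num_programs_for` and the two constants are hypothetical and not part of the PR; the limits are the CUDA maxima for gridDim.x (2^31 - 1) and for gridDim.y / gridDim.z (65535).

```python
# Hypothetical helpers, not taken from the PR.
MAX_GRID_X = 2**31 - 1   # CUDA limit for gridDim.x
MAX_GRID_YZ = 65535      # CUDA limit for gridDim.y and gridDim.z


def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def num_programs_for(M, BLOCK_M, axis=0):
    """Number of programs to launch along `axis`, clamped to the hardware limit."""
    limit = MAX_GRID_X if axis == 0 else MAX_GRID_YZ
    return min(limit, cdiv(M, BLOCK_M))
```

When the clamp takes effect there are fewer programs than tiles, so each program has to step over several tiles, presumably with a stride equal to the number of launched programs.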
M = in0.numel()
N = in1.numel()
if M <= 1024:
    BLOCK_M, BLOCK_N, num_warps = launch_arg(1, 256, N, 4)
The tile size BLOCK_N along the reduction dimension is quite small. Triton has a pretty competitive built-in reduction. Is there room for larger reduction tiles?
It is the autotune result.
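To make the roles of BLOCK_M and BLOCK_N in this thread concrete, here is a rough pure-Python model of a tiled isin, assuming the kernel compares a tile of BLOCK_M elements of the first input against BLOCK_N-sized chunks of the second and reduces each chunk with a logical "any". The function `isin_tiled` is hypothetical and only illustrates the tiling; it is not the PR's kernel.

```python
def isin_tiled(in0, in1, BLOCK_M=4, BLOCK_N=128, invert=False):
    """Model of a tiled isin: M is tiled by BLOCK_M, N is reduced in BLOCK_N chunks."""
    M, N = len(in0), len(in1)
    out = [False] * M
    for m0 in range(0, M, BLOCK_M):          # one "program" per tile of in0
        rows = in0[m0:m0 + BLOCK_M]
        acc = [False] * len(rows)            # running "any" per row of the tile
        for n0 in range(0, N, BLOCK_N):      # walk the reduction dimension
            cols = in1[n0:n0 + BLOCK_N]
            for i, x in enumerate(rows):
                # reduce one BLOCK_N-wide chunk, then fold it into the accumulator
                acc[i] = acc[i] or any(x == y for y in cols)
        for i, hit in enumerate(acc):
            out[m0 + i] = (not hit) if invert else hit
    return out
```

A larger BLOCK_N means fewer passes of the inner chunk loop and more work per reduction, which is the trade-off the reviewer is asking about.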
# launch kernel func
M = in0_ravel.numel()
N = in1_ravel.numel()
if M <= 1048576:  # 2 ** 20 = 1024 * 1024
I don't understand why BLOCK_N is dependent on M.
src/flag_gems/ops/isin.py
Outdated
start = tl.zeros_like(r)
end = start + N
while_mask = start < end
while tl.sum(tl.where(while_mask, out, 1)) != BLOCK_M:
To save an extra reduction at the end of each iteration, I would rather use a fixed loop bound log2(N) + 1 and make sure that start and end meet in the end.
done
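The suggestion in this thread can be sketched in scalar Python as follows, assuming the second input has already been sorted and the search is a lower-bound binary search over the half-open interval [start, end). The function name `isin_sorted` is hypothetical and the code is not the PR's kernel. Each step shrinks an interval of size s to at most floor(s / 2), so floor(log2(N)) + 1 steps always bring start and end together and no per-iteration convergence check is needed.

```python
def isin_sorted(x, sorted_vals):
    """Membership test via binary search with a fixed iteration count."""
    n = len(sorted_vals)
    start, end = 0, n
    # n.bit_length() equals floor(log2(n)) + 1 for n >= 1, and 0 for n == 0.
    for _ in range(n.bit_length()):
        active = start < end          # lanes that already converged do nothing
        mid = (start + end) // 2
        if active:
            if sorted_vals[mid] < x:
                start = mid + 1       # x, if present, lies to the right of mid
            else:
                end = mid             # x, if present, lies at or left of mid
    # After the loop start == end; start is the first index with value >= x.
    return start < n and sorted_vals[start] == x
```

The `active` guard plays the role of a load mask: once a lane's interval is empty, `mid` may equal N, so the element at `mid` must not be read.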
src/flag_gems/ops/isin.py
Outdated
    invert: tl.constexpr,
):
    pid = tl.program_id(0)
    num_ctas = tl.num_programs(0)
num_ctas is a reserved keyword in Triton. We'd better rename it.
done