Autotuner for int mm Triton kernels #41
Conversation
cpuhrsch commented Mar 3, 2024 (edited)
@cpuhrsch has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
About to head into a meeting but will give this a proper read, any chance we could add a test?
some nits
Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default, because it is still an experimental feature and also can take a long time to run.
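For illustration, opting in might look like the snippet below. The variable name `TORCHAO_AUTOTUNER_ENABLE` is an assumption; the README hunk above only says to set "this" to a nonzero value.

```python
import os

# Assumed name for illustration only; set before importing torchao so the
# module-level flag is picked up.
os.environ["TORCHAO_AUTOTUNER_ENABLE"] = "1"

import torchao  # noqa: E402  (import after the flag is set)
```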
Searching a new config can take a long time and we'll save the updated data in `data.pkl`. If you'd like to contribute updated configs for your hardware or shapes, please open a pull request.
Presumably people won't contribute the pickle file, since that's not human readable? It's also kind of a security issue for us to host pickle files.
I added https://github.com/pytorch-labs/ao/pull/41/files#diff-4986e1d3257adc0a73b17fd6f21ef9d3b2c0eaec9027381e6f2de89e5be0e6b5 to make it easier to inspect. It stores the triton Configs, so it's a bit more difficult to make them human readable by default.
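For context, a minimal sketch of how such a pickle of triton Configs could be inspected (the file name is from the thread; the exact layout of the stored dict is an assumption):

```python
import pickle

import triton  # needed so the pickled triton.Config objects can be unpickled

# Assumed layout: a dict mapping a (kernel, shape/dtype) key to a triton.Config.
with open("data.pkl", "rb") as f:
    best_configs = pickle.load(f)

for key, config in best_configs.items():
    # triton.Config exposes kwargs (block sizes), num_warps and num_stages
    print(key, config.kwargs, config.num_warps, config.num_stages)
```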
def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
put this in benchmark util instead?
I'll do that once I add the next benchmark for weight only.
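A rough sketch of what a `benchmark_in_ms(warmup, iters, f, *args, **kwargs)` helper might do, assuming `warmup` and `iters` are iteration counts and CUDA-event timing is used (the PR's actual implementation is not reproduced here):

```python
import torch

def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
    # Warm up to trigger compilation and stabilize clocks.
    for _ in range(warmup):
        f(*args, **kwargs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    # Average time per iteration in milliseconds.
    return start.elapsed_time(end) / iters
```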
benchmarks/sam_shapes.csv
Outdated
@@ -0,0 +1,7 @@
m,k,n
presumably you mean shapes of matmuls in sam?
Yes, SAM vit_b batch size 16 to be precise
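For illustration, the `m,k,n` rows could be consumed like this (a sketch; the actual shape values in the CSV are not reproduced here):

```python
import csv

# Read the (m, k, n) matmul shapes recorded from SAM vit_b at batch size 16.
with open("benchmarks/sam_shapes.csv") as f:
    reader = csv.DictReader(f)  # header: m,k,n
    shapes = [(int(r["m"]), int(r["k"]), int(r["n"])) for r in reader]
print(shapes)
```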
test/kernel/test_autotuner.py
Outdated
[
("cuda", torch.bfloat16),
("cuda", torch.bfloat16),
# ("cpu", torch.bfloat16),
nit: remove comments
I'll turn those into TODOs. It should also work on CPU.
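A sketch of how the parametrization might look once the comment becomes a TODO (the test name and body below are hypothetical, not from this PR):

```python
import pytest
import torch

@pytest.mark.parametrize(
    "device, dtype",
    [
        ("cuda", torch.bfloat16),
        # TODO: enable once the CPU path is verified
        # ("cpu", torch.bfloat16),
    ],
)
def test_int_mm(device, dtype):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    a = torch.randn(128, 64, device=device, dtype=dtype)
    b = torch.randn(64, 32, device=device, dtype=dtype)
    # Placeholder body; the real test would route through the autotuned int mm kernel.
    out = torch.matmul(a, b)
    assert out.shape == (128, 32)
```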
torchao/kernel/README.md
Outdated
@@ -0,0 +1,19 @@
## Autotuner and custom Triton kernels

### Use case |
Is the intent to fill this out later? It might be better to open an issue if you'd like to do this later.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
warmup is not a time; it seems to be an int, so not in ms.
Why can't milliseconds be given in int?
I think this is just saying "run warmup until at least 25ms were spent"
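To make that interpretation concrete, here is one way a millisecond-bounded warmup could be written (a sketch of the behavior described above, not the PR's code; it measures host wall time, so GPU work would additionally need a synchronize per call):

```python
import time

def warmup_for_at_least(f, warmup_ms, *args, **kwargs):
    # Keep calling f until at least `warmup_ms` milliseconds have been spent.
    spent = 0.0
    while spent < warmup_ms:
        t0 = time.perf_counter()
        f(*args, **kwargs)
        spent += (time.perf_counter() - t0) * 1e3
```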
torchao/kernel/autotuner.py
Outdated
BEST_CONFIGS = None
AUTOTUNER_DATA_PATH = os.getenv('TORCHAO_AUTOTUNER_DATA_PATH', None)
put all global variables at the top of the file so they're easier to find
Good point
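For example, the module-level state could be grouped at the top of `autotuner.py` (a sketch of the suggested layout):

```python
import os

# Module-level state, grouped at the top so it is easy to find.
AUTOTUNER_DATA_PATH = os.getenv('TORCHAO_AUTOTUNER_DATA_PATH', None)
BEST_CONFIGS = None  # lazily populated cache of the best config per input key
```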
int8_powers_of_two,
int8_powers_of_two)], [])

# int8_mm_kernel_configs = [
delete?
I wanted to leave these as reference from core. I'll add a comment.
import triton.language as tl
import itertools
import os
int8_powers_of_two = [32, 64, 128, 256]
do you envision people wanting to add more options here and for int8 kernel configs?
Eventually, yes. Follow up here includes making it more extensible for other kernels. Adding support for mixed precision should help that.
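A sketch of how the search space might be generated from these lists: a cartesian product over candidate block sizes turned into `triton.Config` objects. The exact kwarg names and the num_warps/num_stages values are assumptions based on typical Triton matmul kernels, not the PR's code.

```python
import itertools

import triton

int8_powers_of_two = [32, 64, 128, 256]

# Extending the list above (or adding num_warps/num_stages choices) grows the search space.
int8_mm_kernel_configs = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
        num_warps=8,
        num_stages=3,
    )
    for block_m, block_n, block_k in itertools.product(
        int8_powers_of_two, int8_powers_of_two, int8_powers_of_two
    )
]
```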
Unblocking for now
@cpuhrsch has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.