-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CK_TILE] Tune fmha fwd splitkv codgen #110
[CK_TILE] Tune fmha fwd splitkv codgen #110
Conversation
{ | ||
int device; | ||
auto status = hipGetDevice(&device); | ||
if(status != hipSuccess) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
follow FA coding style
if(status != hipSuccess) { return num_splits; }
|
||
hipDeviceProp_t props{}; | ||
status = hipGetDeviceProperties(&props, device); | ||
if(status != hipSuccess) | ||
{ | ||
return num_splits; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if(status != hipSuccess) { return num_splits; }
// get kM0 for prefill phase | ||
if(is_prefill) | ||
{ | ||
return 128; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
}; | ||
|
||
for(auto [hdim, m0] : hdim_to_m0) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
{ | ||
if(hdim_q <= hdim && hdim_v <= hdim) | ||
{ | ||
return m0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
|
||
if(num_splits < 1 && p_drop == 0.0f) | ||
return num_splits_heuristic_ck( | ||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } | ||
max_splits = std::min({max_splits, num_SMs, num_n_blocks}); | ||
if(batch_nhead_mblocks >= 0.8f * num_SMs) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
std::array<float, num_splits_array.size()> efficiency; | ||
|
||
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
float eff = n_blocks / std::ceil(n_blocks); | ||
|
||
if(eff > max_efficiency) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_seqlen_q
// printf("num_splits chosen = %d\n", num_splits); | ||
return num_splits; | ||
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
coding style
c44f7a7
to
ddcc375
Compare
see CK PR for more info.