-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Snippets] SplitDimensionM: heuristic update #28180
base: master
Are you sure you want to change the base?
[Snippets] SplitDimensionM: heuristic update #28180
Conversation
205e16c
to
78a43ee
Compare
@a-sidorova @IvanNovoselov could you please take a look? Thanks in advance |
if (divisor >= min_kernel_m) | ||
return std::make_pair(m_dim / divisor, divisor); | ||
const size_t m_kernel = m_dim / divisor; | ||
if (m_kernel >= min_kernel_m) { | ||
best_result.first = divisor; | ||
best_result.second = m_kernel; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear from this code why we try to find maximal divisor? If divisor >= min_kernel_m
we return, but if m_kernel >= min_kernel_m
we continue. Why?
It seems to me that this problem is symmetrical and should be treated accordingly:
if (m_dim % divisor == 0)
=> we can split M, great. There are 2 ways to do that:
a. (m_dim / divisor, divisor)
b. (divisor, m_dim / divisor)- These 2 ways are absolutely identical except for that this always holds:
divisor < (m_dim / divisor)
- It means that one way to split optimally is to start from the max divisor (=sqrt(m_dim)), go downward and return as soon as the parallel work is sufficient:
if(batch_dim * m_dim/divisor >= optimal_parallelism_work_amount)
. This way we'll make sure that both the kernel WA is maximal (since we're going downwards) and the parallel WA is optimal. - Alternatively, we can start from the minimal acceptable divisor, which should be min_kernel_m, go upward and return as soon as (m_dim % divisor == 0). This way we'll guarantee that the kernel work amount is larger than the minimal one (since we started from min_kernel_m) and the parallel work amount is maximal (since
m_dim / divisor
is deceasing). But that's not the case in this particular function, since we want to maximize the kernel WA. - Sometimes these
min_kernel_m
andoptimal_parallelism_work_amount
can be mutually exclusive, so we should think carefully which is more important. I guess that the parallel work amount should be prioritized, so the approach from 3 should be used. - It looks like we try to implement a mix of the above strategies here: we inspect both a and b splits, and return a if the min_kernel WA is achieved and b if parallel WA is satisfied. This may be not consistent in some circumstances, especially when the parallel & kernel WA limitations can't be fulfilled simultaneously.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main thing I want to point out is that this heuristic maximizes m_batch
(= minimizes m_kernel
) + has a limitation m_kernel >= min_kernel_m
. In other words, we try to find m_kernel
bigger than min_kernel_m
and at the same time as close as possible to this value.
So batch_dim * m_dim/divisor >= optimal_parallelism_work_amount
is not sufficient criteria. Moreover, we can start this heuristic when batch_dim
is already bigger than optimal_parallelism_work_amount
. For the motivation of this strategy, please refer to the function's description:
Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration.
My logic is structured in the following way (taking into account that divisor is ascending) for the splitting candidates:
- if divisor is more than
min_kernel_m
, (a) strategy is used, and we can guarantee that this is the most optimal implementation fromm_kernel
minimization perspective. - if divisor is less than
min_kernel_m
, (b) strategy is used. But it is not guaranteed that the currentm_kernel = m_dim / divisor
is minimized: one of the next divisor from(divisor, sqrt(m_dim))
interval can be more optimal.
Alternatively, I can implement the same logic via 2 for
s with different traversal directions:
for (size_t divisor = min_kernel_m; divisor < std::sqrt(m_dim); ++divisor)
for (size_t divisor = min_kernel_m - 1; divisor > 1; --divisor)
Sometimes these
min_kernel_m
andoptimal_parallelism_work_amount
can be mutually exclusive, so we should think carefully which is more important. I guess that the parallel work amount should be prioritized, so the approach from 3 should be used.
This is true. But the current heuristic covers the most important cases (big shapes in SD topology), at least on the machines where these changes were tested. And we agreed offline that we need to limit these changes' impact on other topologies.
Anyway, if I have time on validation, I will try to further tune heuristic to cover the described situation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I see, thanks for the detailed explanation 🙂
It's probably a good idea to add a corresponding comment, it might be handy if we'll need to update this heuristic.
static std::pair<size_t, size_t> compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); | ||
/** | ||
* @brief Aggressively splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration. | ||
*/ | ||
static std::pair<size_t, size_t> compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); | ||
/** | ||
* @brief Conservatively splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval | ||
*/ | ||
static std::pair<size_t, size_t> compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few minor comments with respect to the methods' names:
- the functions try to split M dimension into two: batch_m and new_m, so it seems logical to call them smth like
get_splitted_*
orsplit_m_*
orcompute_m_split
etc. Heuristic is used inside this functions to perform the split, so the functions do not compute heuristic, but rather use it to compute smth else (split). - We should try to use more descriptive names instead of
aggressive
orconservative
, because what seems conservative to one may seem aggressive to the other. Possible options include:split_ideally
,split_max_kernel_work
andsplit_max_parallel_work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to apply your suggestions. The only left thing is split_conservatively_increase_parallel_wa
: "conservatively" is still ambiguous, but I have no idea what adverb reflects the function's description better, so if you have any ideas please let me know
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably split_fallback_increase_parallel_wa
, since it's used as a last resort to increase parallel wa? Or split_default_...
also an option here
return splited; | ||
} | ||
|
||
std::pair<size_t, size_t> SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { | ||
constexpr size_t min_kernel_m = 32; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In second case in compute_ideal_cases_heuristic
min_kernel_m
is 64 while this is 32 here.
What's about to use always 64 and set as const static attribute of the class?
Or is there difference between heuristics and we really need to have smaller min_kernel_m
in aggressive?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that it's better to have one min_kernel_m
value. But I think that it should be 32, not 64. 64 value was set to empirically avoid the cases in which external repacking feature doesn't work, and overheads on repacking duplication inside kernel are bigger than benefits from the splitting. If external repacking works (and it seems like it will work in all cases after tokenization adjustments), we can easily lower min_kernel_m
for compute_ideal_cases_heuristic
// If M dim is big enough, aggressive heuristic is used for kernel_m minimization. | ||
// For smaller M dim, conservative heuristic is used to preserve old behavour. | ||
const bool big_m_dim = m_dim >= 4000; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the way, we also can support the case with small M
.
If batch < optimal_parallelism_work_amount
and M
is quite small (for example, M
< 64), nothing is needed to be updated and splitted - let's execute as it is.
I don't insist to do it in this PR but I have some models (for example, action-recognition or levit) with small values of batch and M
where this pass is applied and there will be M = 4
or even M = 1
. And this action leads to perf degrdation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like your idea. In third-party brgemm heuristics, I saw that minimal allowed m_kernel
is 16. Probably we can take this into account in our heuristics.
But it's also important is that SplitDimensionM::split
is used in CPU callback (via can_be_optimized
), so if it returns false, the MHA tokenization doesn't happen. So another question that we need to answer is whether we need to even tokenize such MHA's or not
78a43ee
to
4242663
Compare
Details:
Tickets: