-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MetaSchedule] Add Gradient Based Task Scheduler #10366
[MetaSchedule] Add Gradient Based Task Scheduler #10366
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@junrushao1994 Would you please take another look and get it merged some time soon? |
01c22c3
to
e52b66f
Compare
Will review next week! |
Can we merge this? @zxybazh @junrushao1994 |
@masahi thanks for asking! We realized that there is an important problem not addressed yet in this PR (task weight prioritization). With other tasks of higher priority, we would love to delay the PR to next month. |
f698b4e
to
7723e4e
Compare
Hey Junru, the changes of API looks reasonable to me. Would you please retrigger the CI? |
3052796
to
540fac7
Compare
540fac7
to
e3fbb79
Compare
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
@@ -98,14 +97,14 @@ def test_tune_matmul_cuda_tensor_core(): | |||
target = Target("nvidia/geforce-rtx-3070") | |||
config = ReplayTraceConfig( | |||
num_trials_per_iter=32, | |||
num_trials_total=320, | |||
max_trials_per_task=320, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zxybazh @junrushao1994 All uses of ReplayTraceConfig
in tests are broken since now it has three attributes https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L52-L57
But they are not detected in CI since the tests are skipped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think something weird is going on.
First, max_trials_global
is not used at https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L60
And if I use
config=ReplayTraceConfig(
num_trials_per_iter=32,
max_trials_per_task=32,
...
, max_trials_per_task
acts like a global max trials, so only one task is tuned, as shown below. This contradicts the name max_trials_per_task
and does something very different from the previous num_trials_total
.
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_nn_conv2d_add | 12870144 | 1 | 51.3887 | 250.4468 | 250.4468 | 32 |
1 | fused_nn_conv2d_add_1 | 12895232 | 1 | N/A | N/A | N/A | 0 |
2 | fused_nn_conv2d_add_2 | 12945408 | 1 | N/A | N/A | N/A | 0 |
3 | fused_layout_transform | 1 | 1 | N/A | N/A | N/A | 0 |
4 | fused_nn_conv2d_add_nn_relu | 237633536 | 1 | N/A | N/A | N/A | 0 |
5 | fused_nn_max_pool2d | 1806336 | 1 | N/A | N/A | N/A | 0 |
6 | fused_nn_conv2d_add_nn_relu_1 | 231612416 | 2 | N/A | N/A | N/A | 0 |
7 | fused_nn_conv2d_add_add_nn_relu | 231813120 | 2 | N/A | N/A | N/A | 0 |
8 | fused_nn_conv2d_add_nn_relu_2 | 115806208 | 1 | N/A | N/A | N/A | 0 |
9 | fused_nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu | 93227008 | 1 | N/A | N/A | N/A | 0 |
10 | fused_nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu | 93327360 | 2 | N/A | N/A | N/A | 0 |
11 | fused_nn_conv2d_add_nn_relu_3 | 115705856 | 1 | N/A | N/A | N/A | 0 |
12 | fused_nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1 | 98600960 | 1 | N/A | N/A | N/A | 0 |
13 | fused_nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1 | 98651136 | 2 | N/A | N/A | N/A | 0 |
14 | fused_nn_conv2d_add_nn_relu_4 | 115655680 | 1 | N/A | N/A | N/A | 0 |
15 | fused_nn_conv2d_add_nn_relu_5 | 231261184 | 1 | N/A | N/A | N/A | 0 |
16 | fused_nn_conv2d_add_add_nn_relu_1 | 231286272 | 2 | N/A | N/A | N/A | 0 |
17 | fused_nn_adaptive_avg_pool2d | 25600 | 1 | N/A | N/A | N/A | 0 |
18 | fused_layout_transform_reshape_squeeze | 1 | 1 | N/A | N/A | N/A | 0 |
19 | fused_nn_dense_add | 1025000 | 1 | N/A | N/A | N/A | 0 |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @masahi, thanks for detecting this issue swiftly! I apologize for the inconvenience as the end API is still in flux.
Yes, this PR breaks the existing unstable MetaSchedule API, and we want to do this from very early on, so that the APIs won't be broken again after formally announcing MetaSchedule is ready for production use.
The intention of having the parameter config
is to provide an all-in-one place for end users to better configure the search - even if it's less realistic given builder/runner/database/cost_model/task_scheduler actually all need to be configured separately, it's intended to be the most frequent parameter users need to tweak.
Then let's discuss about max_trials_per_task
and max_trials_global
. Imagine a service that allows users to set how many a global uplimit of total number of trials for an entire model, as well as a per-task limit for each individual task extracted, then these are the two parameters to tweak. This PR adds this feature essentially for completeness of such potential SaaS feature request.
I have to admit that this breaking change to unannounced APIs could be damaging particularly to early users like you for sudden surprises; On the other hand, we do want to work together, do the correct thing to make those APIs look right to the end users. In the near future, we will likely introduce breaking changes including:
- Based on your work, refactoring
integration
intoextract_task
andapply_history_best
- More structured and readable logging
- Less error-prone interface in
tune.py
(needs to collect more feedbacks)
We will do our best to ping early users like you to make sure people are aware of our breaking change before officially announcing MetaSchedule is product-ready. After being product ready, we will be less aggressive in terms of refactoring. Thank you so much for your understanding!
First, max_trials_global is not used at https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L60
Yes, and this parameter is only used for configuring the task scheduler.
And if I use
config=ReplayTraceConfig( num_trials_per_iter=32, max_trials_per_task=32, ...
,
max_trials_per_task
acts like a global max trials, so only one task is tuned, as shown below. This contradicts the namemax_trials_per_task
and does something very different from the previousnum_trials_total
.
Indeed this behavior doesn't make any sense. May I ask what your max_trials_global
is in this case? It could lead to early stopping if it's less than 32 * num_tasks
.
All uses of ReplayTraceConfig in tests are broken since now it has three attributes.
We updated the usage of this API in the following files:
- python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
- python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
- tests/python/unittest/test_meta_schedule_measure_callback.py
- tests/python/unittest/test_meta_schedule_search_strategy.py
- tests/python/unittest/test_meta_schedule_task_scheduler.py
- tests/python/unittest/test_meta_schedule_tune_relay.py
- tests/python/unittest/test_meta_schedule_tune_te.py
- tests/python/unittest/test_meta_schedule_tune_tir.py
And indeed we missed the following two files (essentially skipped ones):
- tests/python/unittest/test_meta_schedule_tune_tir.py
- tests/python/unittest/test_meta_schedule_tune_te.py
- tests/python/unittest/test_meta_schedule_tune_relay.py
Clearly we forgot to do so for the skipped files, and I will shoot a PR to get them fixed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We updated the usage of this API in the following files:
python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
tests/python/unittest/test_meta_schedule_measure_callback.py
tests/python/unittest/test_meta_schedule_search_strategy.py
tests/python/unittest/test_meta_schedule_task_scheduler.py
tests/python/unittest/test_meta_schedule_tune_relay.py
tests/python/unittest/test_meta_schedule_tune_te.py
tests/python/unittest/test_meta_schedule_tune_tir.py
I don't see any updated usage... can you double check? Actually ReplayTraceConfig
seems to be used only in three integration test files.
But yeah, I didn't notice that max_trials_global
is used outside of ReplayTraceConfig
class as config.max_trials_global
, so I thought its value doesn't matter. I'll revisit this issue after seeing your fix to the integration tests.
I understand the risk of being an early user, but I'd like to see at least reasonable documentation of the API at all times, regardless of being stable / unstable. Especially when you introduce a breaking change. In this particular case, the new attribute max_trials_global
is not documented and there is no example of how to use this value, as far as I can see.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the risk of being an early user, but I'd like to see at least reasonable documentation of the API at all times, regardless of being stable / unstable. Especially when you introduce a breaking change. In this particular case, the new attribute
max_trials_global
is not documented and there is no example of how to use this value, as far as I can see.
That's a reasonable ask. I'm drafting a bugfix PR which will in the mean time make sure the document is clear. I'm sorry that it could potentially miss some corner cases during our death march last week on performance alignment in BERT.
A further step, I would love to have a user-facing API audit with you and the rest of the community, so that we could build broader consensus of what things should look like.
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
This PR is further improvement of the meta schedule project (#8473).
This PR introduces a gradient-based task scheduler which favors tuning task with more potential and calculates the gradient towards maximizing the final efficiency given computing budget. The gradients are computed in similar way as in auto scheduler. Unittest is also included.