Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Add Gradient Based Task Scheduler #10366

Merged

Conversation

zxybazh
Copy link
Member

@zxybazh zxybazh commented Feb 24, 2022

This PR is further improvement of the meta schedule project (#8473).

This PR introduces a gradient-based task scheduler which favors tuning task with more potential and calculates the gradient towards maximizing the final efficiency given computing budget. The gradients are computed in similar way as in auto scheduler. Unittest is also included.

@zxybazh
Copy link
Member Author

zxybazh commented Feb 24, 2022

CC @junrushao1994 @comaniac

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM.

python/tvm/meta_schedule/task_scheduler/gradient_based.py Outdated Show resolved Hide resolved
python/tvm/meta_schedule/task_scheduler/gradient_based.py Outdated Show resolved Hide resolved
python/tvm/meta_schedule/task_scheduler/gradient_based.py Outdated Show resolved Hide resolved
python/tvm/meta_schedule/task_scheduler/gradient_based.py Outdated Show resolved Hide resolved
src/meta_schedule/task_scheduler/gradient_based.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zxybazh
Copy link
Member Author

zxybazh commented Mar 3, 2022

@junrushao1994 Would you please take another look and get it merged some time soon?

@zxybazh zxybazh force-pushed the feature/2022-02-23/gradient-task-scheduler branch from 01c22c3 to e52b66f Compare March 4, 2022 05:55
@junrushao
Copy link
Member

Will review next week!

src/meta_schedule/utils.h Outdated Show resolved Hide resolved
@masahi
Copy link
Member

masahi commented Mar 24, 2022

Can we merge this? @zxybazh @junrushao1994

@junrushao
Copy link
Member

@masahi thanks for asking! We realized that there is an important problem not addressed yet in this PR (task weight prioritization). With other tasks of higher priority, we would love to delay the PR to next month.

@junrushao junrushao force-pushed the feature/2022-02-23/gradient-task-scheduler branch 3 times, most recently from f698b4e to 7723e4e Compare March 30, 2022 20:34
@zxybazh
Copy link
Member Author

zxybazh commented Mar 30, 2022

Hey Junru, the changes of API looks reasonable to me. Would you please retrigger the CI?

@junrushao junrushao force-pushed the feature/2022-02-23/gradient-task-scheduler branch 3 times, most recently from 3052796 to 540fac7 Compare March 31, 2022 06:52
@junrushao junrushao force-pushed the feature/2022-02-23/gradient-task-scheduler branch from 540fac7 to e3fbb79 Compare March 31, 2022 08:35
@Hzfengsy Hzfengsy merged commit 5629f8a into apache:main Mar 31, 2022
junrushao added a commit to junrushao/tvm that referenced this pull request Mar 31, 2022
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
@@ -98,14 +97,14 @@ def test_tune_matmul_cuda_tensor_core():
target = Target("nvidia/geforce-rtx-3070")
config = ReplayTraceConfig(
num_trials_per_iter=32,
num_trials_total=320,
max_trials_per_task=320,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zxybazh @junrushao1994 All uses of ReplayTraceConfig in tests are broken since now it has three attributes https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L52-L57

But they are not detected in CI since the tests are skipped.

Copy link
Member

@masahi masahi Apr 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something weird is going on.

First, max_trials_global is not used at https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L60

And if I use

            config=ReplayTraceConfig(
                num_trials_per_iter=32,
                max_trials_per_task=32,
                ...

, max_trials_per_task acts like a global max trials, so only one task is tuned, as shown below. This contradicts the name max_trials_per_task and does something very different from the previous num_trials_total.

 ID |                                                                        Name |      FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated 
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  0 |                                                         fused_nn_conv2d_add |  12870144 |      1 |        51.3887 |     250.4468 |              250.4468 |     32 |            
  1 |                                                       fused_nn_conv2d_add_1 |  12895232 |      1 |            N/A |          N/A |                   N/A |      0 |            
  2 |                                                       fused_nn_conv2d_add_2 |  12945408 |      1 |            N/A |          N/A |                   N/A |      0 |            
  3 |                                                      fused_layout_transform |         1 |      1 |            N/A |          N/A |                   N/A |      0 |            
  4 |                                                 fused_nn_conv2d_add_nn_relu | 237633536 |      1 |            N/A |          N/A |                   N/A |      0 |            
  5 |                                                         fused_nn_max_pool2d |   1806336 |      1 |            N/A |          N/A |                   N/A |      0 |            
  6 |                                               fused_nn_conv2d_add_nn_relu_1 | 231612416 |      2 |            N/A |          N/A |                   N/A |      0 |            
  7 |                                             fused_nn_conv2d_add_add_nn_relu | 231813120 |      2 |            N/A |          N/A |                   N/A |      0 |            
  8 |                                               fused_nn_conv2d_add_nn_relu_2 | 115806208 |      1 |            N/A |          N/A |                   N/A |      0 |            
  9 |       fused_nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu |  93227008 |      1 |            N/A |          N/A |                   N/A |      0 |            
 10 |   fused_nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu |  93327360 |      2 |            N/A |          N/A |                   N/A |      0 |            
 11 |                                               fused_nn_conv2d_add_nn_relu_3 | 115705856 |      1 |            N/A |          N/A |                   N/A |      0 |            
 12 |     fused_nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1 |  98600960 |      1 |            N/A |          N/A |                   N/A |      0 |            
 13 | fused_nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1 |  98651136 |      2 |            N/A |          N/A |                   N/A |      0 |            
 14 |                                               fused_nn_conv2d_add_nn_relu_4 | 115655680 |      1 |            N/A |          N/A |                   N/A |      0 |            
 15 |                                               fused_nn_conv2d_add_nn_relu_5 | 231261184 |      1 |            N/A |          N/A |                   N/A |      0 |            
 16 |                                           fused_nn_conv2d_add_add_nn_relu_1 | 231286272 |      2 |            N/A |          N/A |                   N/A |      0 |            
 17 |                                                fused_nn_adaptive_avg_pool2d |     25600 |      1 |            N/A |          N/A |                   N/A |      0 |            
 18 |                                      fused_layout_transform_reshape_squeeze |         1 |      1 |            N/A |          N/A |                   N/A |      0 |            
 19 |                                                          fused_nn_dense_add |   1025000 |      1 |            N/A |          N/A |                   N/A |      0 |            
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @masahi, thanks for detecting this issue swiftly! I apologize for the inconvenience as the end API is still in flux.

Yes, this PR breaks the existing unstable MetaSchedule API, and we want to do this from very early on, so that the APIs won't be broken again after formally announcing MetaSchedule is ready for production use.

The intention of having the parameter config is to provide an all-in-one place for end users to better configure the search - even if it's less realistic given builder/runner/database/cost_model/task_scheduler actually all need to be configured separately, it's intended to be the most frequent parameter users need to tweak.

Then let's discuss about max_trials_per_task and max_trials_global. Imagine a service that allows users to set how many a global uplimit of total number of trials for an entire model, as well as a per-task limit for each individual task extracted, then these are the two parameters to tweak. This PR adds this feature essentially for completeness of such potential SaaS feature request.

I have to admit that this breaking change to unannounced APIs could be damaging particularly to early users like you for sudden surprises; On the other hand, we do want to work together, do the correct thing to make those APIs look right to the end users. In the near future, we will likely introduce breaking changes including:

  • Based on your work, refactoring integration into extract_task and apply_history_best
  • More structured and readable logging
  • Less error-prone interface in tune.py (needs to collect more feedbacks)

We will do our best to ping early users like you to make sure people are aware of our breaking change before officially announcing MetaSchedule is product-ready. After being product ready, we will be less aggressive in terms of refactoring. Thank you so much for your understanding!

First, max_trials_global is not used at https://github.com/zxybazh/tvm/blob/e3fbb797a88308e4ce3d671939a83084ae1826b8/python/tvm/meta_schedule/search_strategy/replay_trace.py#L60

Yes, and this parameter is only used for configuring the task scheduler.

And if I use

            config=ReplayTraceConfig(
                num_trials_per_iter=32,
                max_trials_per_task=32,
                ...

, max_trials_per_task acts like a global max trials, so only one task is tuned, as shown below. This contradicts the name max_trials_per_task and does something very different from the previous num_trials_total.

Indeed this behavior doesn't make any sense. May I ask what your max_trials_global is in this case? It could lead to early stopping if it's less than 32 * num_tasks.

All uses of ReplayTraceConfig in tests are broken since now it has three attributes.

We updated the usage of this API in the following files:

  • python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
  • python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
  • tests/python/unittest/test_meta_schedule_measure_callback.py
  • tests/python/unittest/test_meta_schedule_search_strategy.py
  • tests/python/unittest/test_meta_schedule_task_scheduler.py
  • tests/python/unittest/test_meta_schedule_tune_relay.py
  • tests/python/unittest/test_meta_schedule_tune_te.py
  • tests/python/unittest/test_meta_schedule_tune_tir.py

And indeed we missed the following two files (essentially skipped ones):

  • tests/python/unittest/test_meta_schedule_tune_tir.py
  • tests/python/unittest/test_meta_schedule_tune_te.py
  • tests/python/unittest/test_meta_schedule_tune_relay.py

Clearly we forgot to do so for the skipped files, and I will shoot a PR to get them fixed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We updated the usage of this API in the following files:

python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
tests/python/unittest/test_meta_schedule_measure_callback.py
tests/python/unittest/test_meta_schedule_search_strategy.py
tests/python/unittest/test_meta_schedule_task_scheduler.py
tests/python/unittest/test_meta_schedule_tune_relay.py
tests/python/unittest/test_meta_schedule_tune_te.py
tests/python/unittest/test_meta_schedule_tune_tir.py

I don't see any updated usage... can you double check? Actually ReplayTraceConfig seems to be used only in three integration test files.

But yeah, I didn't notice that max_trials_global is used outside of ReplayTraceConfig class as config.max_trials_global, so I thought its value doesn't matter. I'll revisit this issue after seeing your fix to the integration tests.

I understand the risk of being an early user, but I'd like to see at least reasonable documentation of the API at all times, regardless of being stable / unstable. Especially when you introduce a breaking change. In this particular case, the new attribute max_trials_global is not documented and there is no example of how to use this value, as far as I can see.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the risk of being an early user, but I'd like to see at least reasonable documentation of the API at all times, regardless of being stable / unstable. Especially when you introduce a breaking change. In this particular case, the new attribute max_trials_global is not documented and there is no example of how to use this value, as far as I can see.

That's a reasonable ask. I'm drafting a bugfix PR which will in the mean time make sure the document is clear. I'm sorry that it could potentially miss some corner cases during our death march last week on performance alignment in BERT.

A further step, I would love to have a user-facing API audit with you and the rest of the community, so that we could build broader consensus of what things should look like.

pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 11, 2022
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants