
[TOPI] Custom schedule for standalone transpose in cuda #8030

Merged: 12 commits merged into apache:main on May 20, 2021

Conversation

tkonolige (Contributor):

This PR adds an optimized schedule for transpose if the transpose is not fused into anything else.

@altanh @junrushao1994

@altanh (Contributor) left a comment:

Some minor nits about code organization and the test, but otherwise looks good and is a welcome addition. For other reviewers: offline testing of this transpose schedule (when applicable) showed about a 2x improvement over the baseline injective transpose schedule on BERT Large training workloads. I do wonder how we could select whether or not to opt in to this vs trying to tune the original injective schedule, but I think in most cases this new kernel will probably be faster.

Dispatches to an optimized schedule if the transpose is standalone (not fused).
"""
warp_size = int(Target.current(allow_none=False).thread_warp_size)
if (
Contributor:

Is there a more principled way to do this? Like maybe with an OpStrategy or something?

tkonolige (Author):

As far as I can tell, there is not a better way to do this. There is a way to add implementations based on input sizes, but these are not on a per-target basis. If you know a better way, let me know.
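
For readers following the thread: a minimal sketch of the kind of guard being discussed, based on the snippet above. The helper name and the exact set of conditions are illustrative, not the merged code:

    from tvm import te, tir
    from tvm.target import Target

    def _is_standalone_transpose(out):
        # Hypothetical helper: only dispatch to the custom schedule for a
        # standalone 2-D transpose with static dims at least one warp wide.
        warp_size = int(Target.current(allow_none=False).thread_warp_size)
        return (
            len(out.shape) == 2
            # "Standalone" means the input comes straight from a placeholder,
            # i.e. the transpose was not fused with any other op.
            and all(isinstance(t.op, te.PlaceholderOp) for t in out.op.input_tensors)
            # Static shapes only (rules out Any), each dim >= warp_size.
            and all(isinstance(d, tir.IntImm) and int(d) >= warp_size for d in out.shape)
        )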

@@ -105,13 +105,22 @@ def _callback(op):
return s


def schedule_cuda_transpose(s, out):
def schedule_transpose(outs):
Contributor:

Feels a bit weird to have this in sparse.py.

tkonolige (Author):

Moved to transform.py.
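
For context, here is a minimal sketch of a shared-memory tiled transpose schedule in TE, in the spirit of the classic CUDA transpose kernel. It is not the merged schedule_transpose (which now lives in transform.py), and it omits the shared-memory padding that a production kernel would use to avoid bank conflicts:

    from tvm import te
    from tvm.target import Target

    def schedule_transpose_sketch(out):
        # out is assumed to be a 2-D transpose of a single placeholder input.
        s = te.create_schedule(out.op)
        warp_size = int(Target.current(allow_none=False).thread_warp_size)
        # Stage the input tile through shared memory so that both the global
        # load and the global store are coalesced.
        shared = s.cache_read(out.op.input_tensors[0], "shared", [out])
        y, x = s[out].op.axis
        yo, yi = s[out].split(y, factor=warp_size)
        xo, xi = s[out].split(x, factor=warp_size)
        s[out].reorder(yo, xo, yi, xi)
        s[out].bind(yo, te.thread_axis("blockIdx.y"))
        s[out].bind(xo, te.thread_axis("blockIdx.x"))
        s[out].bind(yi, te.thread_axis("threadIdx.y"))
        s[out].bind(xi, te.thread_axis("threadIdx.x"))
        # Load the tile cooperatively, one element per thread.
        s[shared].compute_at(s[out], xo)
        sy, sx = s[shared].op.axis
        s[shared].bind(sy, te.thread_axis("threadIdx.y"))
        s[shared].bind(sx, te.thread_axis("threadIdx.x"))
        return s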

r = np.random.rand(*shape)
tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r))

# make sure schedule does not fire here
Contributor:

Is this a TODO? Also, I wonder if it would be good to parametrize the test shape by the warp size (rather than hard-coding it) for future-proofing.

tkonolige (Author):

It is more like a wish. Ideally we would be able to know which schedules were used, but there is no way to introspect on what was used. I've updated the comment to reflect this.
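
A sketch of the warp-size parametrization suggested above; the test name and executor setup are illustrative, and the merged test keeps an explicit shape:

    import numpy as np
    import tvm
    import tvm.testing
    from tvm import relay

    def test_transpose_unfused_schedule(target="cuda"):
        warp_size = int(tvm.target.Target(target).thread_warp_size)
        # Two warps per dimension, so the size guard of the standalone
        # schedule passes regardless of the target's actual warp size.
        shape = (2 * warp_size, 2 * warp_size)
        x = relay.var("x", relay.TensorType(shape, "float32"))
        f = relay.Function([x], relay.transpose(x))
        r = np.random.rand(*shape).astype("float32")
        ex = relay.create_executor(
            kind="graph", mod=tvm.IRModule.from_expr(f), target=target
        )
        tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r))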

@@ -357,7 +357,7 @@ def tune_and_evaluate(tuning_opt):
)

# filter out non-packed conv2d task
tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
tasks = list(filter(lambda t: len(t.args[0][1]) > 4 and "conv" in t.name, tasks))
Contributor:

What happened here? Did this transpose change introduce a new task or something?

tkonolige (Author):

yes

tkonolige (Author):

Actually, no, but this check makes sure anyway.

Contributor:

Isn't the newly added schedule not tunable? Or is there any concern about adding knobs?

tkonolige (Author):

We may want to tune it in the future.

@altanh (Contributor) left a comment:

LGTM. The only thing I'm wondering about: if someone (for whatever reason) really wanted to tune the default injective schedule for transpose, is there any way to allow that?

cc @comaniac for additional review (feel free to tag more relevant reviewers)

@comaniac (Contributor):

> LGTM. The only thing I'm wondering about: if someone (for whatever reason) really wanted to tune the default injective schedule for transpose, is there any way to allow that?
>
> cc @comaniac for additional review (feel free to tag more relevant reviewers)

There's no reason to tune the injective schedule, and you basically cannot do it because the injective schedule doesn't have AutoTVM knobs for tuning.
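
To make the point concrete: AutoTVM only has something to search when a schedule defines knobs in its config. A toy tunable transpose template might look like the sketch below (the task name and knob values are illustrative); the injective schedule defines no such knobs, so there is nothing for the tuner to explore:

    from tvm import autotvm, te

    @autotvm.template("sketch/transpose")  # hypothetical task name
    def transpose_template(n, m):
        A = te.placeholder((n, m), name="A")
        B = te.compute((m, n), lambda i, j: A[j, i], name="B")
        s = te.create_schedule(B.op)
        cfg = autotvm.get_config()
        # The knob is what makes this schedule tunable: AutoTVM searches
        # over these candidate tile sizes.
        cfg.define_knob("tile", [8, 16, 32, 64])
        yo, yi = s[B].split(B.op.axis[0], factor=cfg["tile"].val)
        s[B].bind(yo, te.thread_axis("blockIdx.x"))
        s[B].bind(yi, te.thread_axis("threadIdx.x"))
        return s, [A, B]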

Comment on lines +885 to +886
# We want to make sure schedule does not fire here, but there is no way of
# inspecting which schedules were used.
Contributor:

Like this comment mentions, there is no way of inspecting which schedules were used, so it seems to me that the difference between this test and test_transpose is that the workload in this test includes add, to test the case of fusion. Accordingly, could we just extend test_transpose?

tkonolige (Author):

We could. I'd like to keep it separate so the intention is clear.

Contributor:

Fair enough. Then it might be better to name it test_transpose_fuse or something like that (nit).

tkonolige (Author):

Switched to test_transpose_unfused_schedule.


@comaniac (Contributor) left a comment:

I don't have other comments. LGTM.

@areusch merged commit 28ea03c into apache:main on May 20, 2021
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request on Jun 17, 2021:
* [TOPI] Custom schedule for standalone transpose in cuda

* check if input is not Any

* fix vta test

* check input shape

* fix injective

* move transpose out of sparse.py

* update comments, use warp size

* missspelled transform

* formatting

* rename test

* comment

* fix tests
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request on Jun 17, 2021, with the same commit list as above.