-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[OP] Add rms_norm
into TOPI
#15326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OP] Add rms_norm
into TOPI
#15326
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the nice work! A few minor comments.
|
||
# only test on llvm because schedule is missing | ||
@tvm.testing.parametrize_targets("llvm") | ||
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it also support symbolic shape, would you add a testcase?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, the symbolic shape test has been added.
include/tvm/topi/nn/rms_norm.h
Outdated
<< "rms_norm: only support float32 and float16 for now"; | ||
bool is_float16 = data_type == DataType::Float(16); | ||
|
||
auto x = is_float16 ? cast(data, DataType::Float(32)) : data; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need casting here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually no, and all the cast
have been removed.
This PR introduces the operator root mean square, `rms_norm`, into TOPI.
This PR introduces the operator root mean square, `rms_norm`, into TOPI.
This PR introduces the operator root mean square, `rms_norm`, into TOPI.
This PR introduces the operator root mean square, `rms_norm`, into TOPI.
This PR introduces the operator root mean square,
rms_norm
, into TOPI.