-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Move norm to pten #39324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move norm to pten #39324
Conversation
Thanks for your contribution! |
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/eigen/eigen_function.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在pten下的eigen应该可以用了
// limitations under the License. | ||
|
||
#include "paddle/pten/kernels/norm_kernel.h" | ||
#include "paddle/fluid/operators/eigen/eigen_function.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
namespace cub = hipcub; | ||
#endif | ||
#include "paddle/fluid/operators/amp/fp16_type_traits.h" | ||
#include "paddle/fluid/platform/bfloat16.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bfloat16也迁到pten了
namespace cub = hipcub; | ||
#endif | ||
#include "paddle/fluid/operators/amp/fp16_type_traits.h" | ||
#include "paddle/fluid/platform/bfloat16.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
paddle/fluid/operators/norm_op.cc
Outdated
@@ -72,6 +72,12 @@ class NormOp : public framework::OperatorWithKernel { | |||
ctx->SetOutputDim("Norm", xdim); | |||
} | |||
} | |||
|
|||
framework::KernelSignature GetExpectedPtenKernelArgs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前向的这个映射函数建议就在pten/ops/compat目录下维护吧
#include "paddle/pten/core/ddim.h" | ||
|
||
namespace pten { | ||
inline void GetDims( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
impl目录下定位是放置CPU和GPU代码公用的kernel实现,如果只是辅助函数,建议将函数放到pten/funcs,比如可以放到pten/funcs/common_shape.h中
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是不是需要include一下norm_grad_kernel.h
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议include一下norm_kernel.h
@@ -154,4 +154,6 @@ def test_norm_x_type(): | |||
|
|||
|
|||
if __name__ == '__main__': | |||
import paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import paddle应该可以移到文件头部
… move_norm_to_pten
… move_norm_to_pten
… move_norm_to_pten
… move_norm_to_pten
… move_norm_to_pten
… move_norm_to_pten
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Breaking changes
PR changes
OPs
Describe
move norm to pten