Skip to content

Commit

Permalink
Add fast_ln spmd rules (#9125)
Browse files Browse the repository at this point in the history
  • Loading branch information
From00 authored Sep 13, 2024
1 parent db270d9 commit cd3dc95
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions legacy/model_zoo/gpt-3/external_ops/fast_ln/ln_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
* with minor changes. */

#include "paddle/extension.h"

#include "ln.h" // NOLINT

#ifdef CUSTOM_OP_WITH_SPMD
#include "paddle/phi/api/ext/spmd_infer.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

/*
Supported Type combinations:
Expand Down Expand Up @@ -197,12 +201,10 @@ std::vector<paddle::Tensor> LnFwd(const paddle::Tensor &x,
auto sizes = x.shape();
PD_CHECK(sizes.size() >= 2);

int rows = 1;
for (size_t i = 0; i + 1 < sizes.size(); ++i) {
rows *= sizes[i];
}

std::vector<int> row_sizes(sizes.begin(), sizes.begin() + sizes.size() - 1);

const int cols = sizes[sizes.size() - 1];
const int rows = x.numel() / cols;
auto hidden_size = scale.numel();

PD_CHECK(scale.shape() == bias.shape());
Expand All @@ -214,8 +216,8 @@ std::vector<paddle::Tensor> LnFwd(const paddle::Tensor &x,

auto y = paddle::empty(sizes, output_type, place);

auto mean = paddle::empty({rows}, compute_type, place);
auto invvar = paddle::empty({rows}, compute_type, place);
auto mean = paddle::empty({row_sizes}, compute_type, place);
auto invvar = paddle::empty({row_sizes}, compute_type, place);

LaunchNormFwd(x.stream(),
place,
Expand Down Expand Up @@ -481,11 +483,8 @@ std::vector<std::vector<int64_t>> LnFwdInferShape(
std::vector<int64_t> scale_shape,
std::vector<int64_t> bias_shape,
float epsilon) {
int64_t rows = 1;
for (size_t i = 0; i + 1 < x_shape.size(); ++i) {
rows *= x_shape[i];
}
return {x_shape, {rows}, {rows}};
std::vector<int64_t> row_shape(x_shape.begin(), x_shape.begin() + x_shape.size() - 1);
return {x_shape, row_shape, row_shape};
}

std::vector<std::vector<int64_t>> RMSLnFwdInferShape(
Expand Down Expand Up @@ -543,15 +542,23 @@ PD_BUILD_OP(fast_ln)
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(LnFwd))
.SetInferShapeFn(PD_INFER_SHAPE(LnFwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype))
#ifdef CUSTOM_OP_WITH_SPMD
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnInferSpmd))
#endif
;

PD_BUILD_GRAD_OP(fast_ln)
.Inputs({"x", "scale", "mean", "invvar", paddle::Grad("y")})
.Outputs({paddle::Grad("x"), paddle::Grad("scale"), paddle::Grad("bias")})
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(LnBwd))
.SetInferShapeFn(PD_INFER_SHAPE(LnBwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(LnBwdInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(LnBwdInferDtype))
#ifdef CUSTOM_OP_WITH_SPMD
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnGradInferSpmd))
#endif
;

PD_BUILD_OP(fast_rms_norm)
.Inputs({"x", "scale"})
Expand Down

0 comments on commit cd3dc95

Please sign in to comment.