fused_attention_op parameters stop grad support (PaddlePaddle#49351)
* fusedAttenGrad_noGrad

* code style fix

* add ut

* remove unnecessary log
wwbitejotunn authored Dec 29, 2022
1 parent 1c7ae95 commit 0bb999b
Showing 3 changed files with 375 additions and 26 deletions.
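
Both changed source files apply the same fix: when a parameter of the fused attention op has stop_gradient set, the grad op never declares the corresponding gradient output, so the grad code must not assume that output exists. The shape inference in fused_attention_op.cc now checks HasOutput(...) before each SetOutputDim(...), and the CUDA grad kernel in fused_attention_op.cu allocates a gradient buffer only when the grad tensor pointer is non-null. A minimal sketch of each idiom follows its file's diff below.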
67 changes: 45 additions & 22 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -520,31 +520,50 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
                         ctx->GetInputDim("OutLinearBias"));
     }
-    ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
-                      ctx->GetInputDim("OutLinearW"));
-    ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
+    if (ctx->HasOutput(framework::GradVarName("OutLinearW"))) {
+      ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
+                        ctx->GetInputDim("OutLinearW"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("QKVW"))) {
+      ctx->SetOutputDim(framework::GradVarName("QKVW"),
+                        ctx->GetInputDim("QKVW"));
+    }
     if (ctx->HasOutput(framework::GradVarName("QKVBias"))) {
       ctx->SetOutputDim(framework::GradVarName("QKVBias"),
                         ctx->GetInputDim("QKVBias"));
     }
 
     if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
-      ctx->SetOutputDim(framework::GradVarName("LnOut"),
-                        ctx->GetInputDim("LnOut"));
+      if (ctx->HasOutput(framework::GradVarName("LnOut"))) {
+        ctx->SetOutputDim(framework::GradVarName("LnOut"),
+                          ctx->GetInputDim("LnOut"));
+      }
     } else {
-      ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
-                        ctx->GetInputDim("BiasDropoutResidualOut"));
-    }
-    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
-                      ctx->GetInputDim("FMHAOut"));
-    ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
-                      ctx->GetInputDim("QKTVOut"));
-    ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
-                      ctx->GetInputDim("TransposeOut2"));
-    ctx->SetOutputDim(framework::GradVarName("QKOut"),
-                      ctx->GetInputDim("QKOut"));
-    ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
-                      ctx->GetInputDim("SoftmaxOut"));
+      if (ctx->HasOutput(framework::GradVarName("BiasDropoutResidualOut"))) {
+        ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
+                          ctx->GetInputDim("BiasDropoutResidualOut"));
+      }
+    }
+    if (ctx->HasOutput(framework::GradVarName("FMHAOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
+                        ctx->GetInputDim("FMHAOut"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("QKTVOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
+                        ctx->GetInputDim("QKTVOut"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("TransposeOut2"))) {
+      ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
+                        ctx->GetInputDim("TransposeOut2"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("QKOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("QKOut"),
+                        ctx->GetInputDim("QKOut"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("SoftmaxOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
+                        ctx->GetInputDim("SoftmaxOut"));
+    }
     if (ctx->HasOutput(framework::GradVarName("AttnDropoutOut"))) {
       ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                         ctx->GetInputDim("AttnDropoutOut"));
@@ -554,14 +573,18 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
                         ctx->GetInputDim("SrcMaskOut"));
     }
-    ctx->SetOutputDim(framework::GradVarName("QKVOut"),
-                      ctx->GetInputDim("QKVOut"));
+    if (ctx->HasOutput(framework::GradVarName("QKVOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("QKVOut"),
+                        ctx->GetInputDim("QKVOut"));
+    }
     if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) {
       ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
                         ctx->GetInputDim("QKVBiasOut"));
     }
-    ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
-                      ctx->GetInputDim("OutLinearOut"));
+    if (ctx->HasOutput(framework::GradVarName("OutLinearOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
+                        ctx->GetInputDim("OutLinearOut"));
+    }
   }
 
  protected:
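To see the shape-inference guard in isolation, here is a minimal, self-contained sketch; it is not Paddle code. MockInferContext, its maps, and the local GradVarName are hypothetical stand-ins for the real InferShapeContext and framework helpers, where setting the dims of an output the grad op never declared is an error. Only the control flow mirrors the diff above.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Dims = std::vector<int64_t>;

// Hypothetical stand-in for paddle's InferShapeContext; as in the real
// class, SetOutputDim on an undeclared output fails, which is exactly
// what the added HasOutput guards prevent.
struct MockInferContext {
  std::map<std::string, Dims> input_dims;        // forward tensor shapes
  std::map<std::string, Dims> declared_outputs;  // grad outputs that exist

  bool HasOutput(const std::string& name) const {
    return declared_outputs.count(name) != 0;
  }
  Dims GetInputDim(const std::string& name) const {
    return input_dims.at(name);
  }
  void SetOutputDim(const std::string& name, const Dims& d) {
    declared_outputs.at(name) = d;  // throws if the output was not declared
  }
};

std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  MockInferContext ctx;
  ctx.input_dims["QKVW"] = {3, 8, 64, 512};
  ctx.input_dims["OutLinearW"] = {512, 512};
  // OutLinearW@GRAD is declared; QKVW@GRAD is not, as if QKVW were a
  // parameter with stop_gradient=True.
  ctx.declared_outputs[GradVarName("OutLinearW")] = {};

  for (const std::string& p : {"OutLinearW", "QKVW"}) {
    if (ctx.HasOutput(GradVarName(p))) {  // the guard added by this commit
      ctx.SetOutputDim(GradVarName(p), ctx.GetInputDim(p));
      std::cout << p << "@GRAD shape set\n";
    } else {
      std::cout << p << "@GRAD skipped (stop_gradient)\n";
    }
  }
  return 0;
}

Mapping each parameter's stop_gradient to a missing @GRAD output is what lets one InferShape body serve both the full-gradient and the frozen-parameter cases.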
17 changes: 13 additions & 4 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -514,15 +514,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_ln_2_bias =
         ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
 
-    auto *d_qkv_weight_data = dev_ctx.template Alloc<T>(
-        d_qkv_weight, d_qkv_weight->numel() * sizeof(T));
+    auto *d_qkv_weight_data =
+        (d_qkv_weight == nullptr)
+            ? nullptr
+            : dev_ctx.template Alloc<T>(d_qkv_weight,
+                                        d_qkv_weight->numel() * sizeof(T));
+
     auto *d_qkv_bias_data =
         (d_qkv_bias == nullptr)
             ? nullptr
             : dev_ctx.template Alloc<T>(d_qkv_bias,
                                         d_qkv_bias->numel() * sizeof(T));
-    auto *d_out_linear_weight_data = dev_ctx.template Alloc<T>(
-        d_out_linear_weight, d_out_linear_weight->numel() * sizeof(T));
+    auto *d_out_linear_weight_data =
+        (d_out_linear_weight == nullptr)
+            ? nullptr
+            : dev_ctx.template Alloc<T>(
+                  d_out_linear_weight,
+                  d_out_linear_weight->numel() * sizeof(T));
+
     auto *d_out_linear_bias_data =
         (d_out_linear_bias == nullptr)
             ? nullptr
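The kernel side needs the matching guard, shown here as a minimal sketch with hypothetical stand-ins (FakeTensor, FakeDeviceContext) rather than the real phi::DenseTensor and device context. The point is the ternary idiom from the diff above: a grad tensor that was never created yields a null data pointer, replacing an unconditional Alloc whose numel() call would have dereferenced nullptr.

#include <cstdio>
#include <cstdlib>

// Hypothetical stand-ins; only the null-guarded allocation mirrors the diff.
struct FakeTensor {
  long n = 0;
  long numel() const { return n; }
};

struct FakeDeviceContext {
  template <typename T>
  T* Alloc(FakeTensor* /*t*/, size_t bytes) {
    return static_cast<T*>(std::malloc(bytes));  // device alloc in real code
  }
};

int main() {
  FakeDeviceContext dev_ctx;
  FakeTensor weight_grad{3 * 8 * 64 * 512};
  FakeTensor* d_qkv_weight = &weight_grad;    // grad requested
  FakeTensor* d_out_linear_weight = nullptr;  // stop_gradient: no grad tensor

  // The commit's idiom: check the tensor pointer before touching numel()
  // or allocating, instead of allocating unconditionally.
  float* d_qkv_weight_data =
      (d_qkv_weight == nullptr)
          ? nullptr
          : dev_ctx.Alloc<float>(d_qkv_weight,
                                 d_qkv_weight->numel() * sizeof(float));
  float* d_out_linear_weight_data =
      (d_out_linear_weight == nullptr)
          ? nullptr
          : dev_ctx.Alloc<float>(d_out_linear_weight,
                                 d_out_linear_weight->numel() * sizeof(float));

  std::printf("qkv grad buffer: %p\n", (void*)d_qkv_weight_data);
  std::printf("out_linear grad buffer: %p\n", (void*)d_out_linear_weight_data);
  std::free(d_qkv_weight_data);  // free(nullptr) is a harmless no-op
  return 0;
}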