-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Det &Slogdet #34992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Det &Slogdet #34992
Changes from all commits
b991e1f
548534f
c04f269
3fc793b
04e186c
9da88ba
ca90942
56845ed
224250d
8d75e54
c0ddcfe
d1f9222
e001a62
940d2f0
9d9fa4c
3e98afc
a7e51d5
760b5a5
ae11334
ceb2c43
200562a
20b5c87
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/determinant_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class DeterminantOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant"); | ||
} | ||
}; | ||
|
||
class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Input", "(Tensor) The input tensor of determinant."); | ||
AddOutput("Out", | ||
"(Tensor) The output Tensor containing the determinant" | ||
"value of a square matrix or batches of square matrices "); | ||
|
||
AddComment(R"DOC( | ||
Determinant Operator.)DOC"); | ||
} | ||
}; | ||
|
||
class DeterminantGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", | ||
"DeterminantGradOp"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 反向还需要
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp"); | ||
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", | ||
framework::GradVarName("Input"), "DeterminantGradOp"); | ||
|
||
ctx->SetOutputDim(framework::GradVarName("Input"), | ||
ctx->GetInputDim("Input")); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class DeterminantGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> grad_op) const override { | ||
grad_op->SetType("determinant_grad"); | ||
grad_op->SetInput("Input", this->Input("Input")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 反向还需要
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
grad_op->SetInput("Out", this->Output("Out")); | ||
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
grad_op->SetOutput(framework::GradVarName("Input"), | ||
this->InputGrad("Input")); | ||
grad_op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
DECLARE_NO_NEED_BUFFER_VARS_INFERER(DeterminantGradNoNeedBufferVarsInferer, | ||
"Input"); | ||
|
||
class SlogDeterminantOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant"); | ||
} | ||
}; | ||
|
||
class SlogDeterminantOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Input", "(Tensor) The input tensor of SlogDeterminant."); | ||
AddOutput("Out", | ||
"(Tensor) The output tensor containing the sign of the" | ||
"determinant and the natural logarithm" | ||
"of the absolute value of determinant,"); | ||
|
||
AddComment(R"DOC( | ||
SlogDeterminant Operator.)DOC"); | ||
} | ||
}; | ||
|
||
class SlogDeterminantGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", | ||
"SlogDeterminantGradOp"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", | ||
"SlogDeterminantGradOp"); | ||
|
||
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", | ||
framework::GradVarName("Input"), "SlogDeterminantGradOp"); | ||
|
||
ctx->SetOutputDim(framework::GradVarName("Input"), | ||
ctx->GetInputDim("Input")); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class SlogDeterminantGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> grad_op) const override { | ||
grad_op->SetType("slogdeterminant_grad"); | ||
grad_op->SetInput("Input", this->Input("Input")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
grad_op->SetInput("Out", this->Output("Out")); | ||
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
grad_op->SetOutput(framework::GradVarName("Input"), | ||
this->InputGrad("Input")); | ||
grad_op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, | ||
"Input"); | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker, | ||
ops::DeterminantGradOpMaker<paddle::framework::OpDesc>, | ||
ops::DeterminantGradOpMaker<paddle::imperative::OpBase>); | ||
|
||
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp) | ||
|
||
REGISTER_OP_CPU_KERNEL(determinant, | ||
ops::DeterminantKernel<plat::CPUDeviceContext, float>, | ||
ops::DeterminantKernel<plat::CPUDeviceContext, double>); | ||
|
||
REGISTER_OP_CPU_KERNEL( | ||
determinant_grad, ops::DeterminantGradKernel<plat::CPUDeviceContext, float>, | ||
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>); | ||
|
||
REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, | ||
ops::SlogDeterminantOpMaker, | ||
ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>, | ||
ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>); | ||
|
||
REGISTER_OPERATOR(slogdeterminant_grad, | ||
ops::DeterminantGradOp) // reuse det grad op | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿应该是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里暂时复用 DeterminantGradOp. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
REGISTER_OP_CPU_KERNEL( | ||
slogdeterminant, ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>, | ||
ops::SlogDeterminantKernel<plat::CPUDeviceContext, double>); | ||
|
||
REGISTER_OP_CPU_KERNEL( | ||
slogdeterminant_grad, | ||
ops::DeterminantGradKernel<plat::CPUDeviceContext, float>, | ||
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/determinant_op.h" | ||
#include "paddle/fluid/platform/cuda_primitives.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using platform::PADDLE_CUDA_NUM_THREADS; | ||
using Tensor = framework::Tensor; | ||
|
||
template <typename T> | ||
__global__ void DeterminantGrad(const size_t numel, T* out) { | ||
zhhsplendid marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
if (tid < numel) { | ||
out[tid] = static_cast<T>(1); | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个CUDA kernel可以用set_constant替代 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分后面会被删除。 |
||
|
||
template <typename T> | ||
class DeterminantGradCUDAKernel : public framework::OpKernel<T> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上,这个kernel也可以删了 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
const auto* dout = context.Input<Tensor>(framework::GradVarName("Out")); | ||
const T* dout_data = dout->data<T>(); | ||
auto dout_dim = vectorize(dout->dims()); | ||
|
||
auto* dx = context.Output<Tensor>(framework::GradVarName("Input")); | ||
T* dx_data = dx->mutable_data<T>(context.GetPlace()); | ||
|
||
int64_t numel = dx->numel(); | ||
for (int64_t idx = 0; idx < numel; idx++) { | ||
dx_data[idx] = static_cast<T>(1); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
REGISTER_OP_CUDA_KERNEL( | ||
determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>, | ||
ops::DeterminantKernel<plat::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
determinant_grad, | ||
ops::DeterminantGradKernel<plat::CUDADeviceContext, float>, | ||
ops::DeterminantGradKernel<plat::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>, | ||
ops::SlogDeterminantKernel<plat::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
slogdeterminant_grad, | ||
ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, float>, | ||
ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, double>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should do infer shape at compile time at least?