
Commit c04f269

Add API of slogdet
1 parent 548534f commit c04f269

5 files changed (+425, -26 lines)

CMakeLists.txt

Lines changed: 32 additions & 17 deletions
@@ -12,8 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License

-cmake_minimum_required(VERSION 3.10)
-cmake_policy(VERSION 3.10)
+if(APPLE AND WITH_ARM)
+  # CMake 3.19.2 is the first version that supports the Apple M1
+  cmake_minimum_required(VERSION 3.19.2)
+  cmake_policy(VERSION 3.19.2)
+else(APPLE AND WITH_ARM)
+  cmake_minimum_required(VERSION 3.10)
+  cmake_policy(VERSION 3.10)
+endif(APPLE AND WITH_ARM)
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
@@ -73,6 +79,11 @@ if(WITH_MUSL)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy")
 endif()

+if(APPLE AND WITH_ARM)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -target arm64-apple-darwin")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -target arm64-apple-darwin")
+endif()
+
 if(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
 endif()
@@ -97,10 +108,6 @@ if(WIN32)

   if (MSVC_STATIC_CRT)
     message(STATUS "Use static C runtime time, refer to https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=vs-2019")
-    set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /MTd")
-    set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /MT")
-    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
-    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
     foreach(flag_var
       CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
       CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
@@ -112,17 +119,19 @@ if(WIN32)
     endforeach(flag_var)
   endif()

-  math(EXPR PROCESS_MAX "${CPU_CORES} * 2 / 3")
-
   # windows build turn off warnings, use parallel compiling.
   foreach(flag_var
     CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
     CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
    CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
    CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
     string(REGEX REPLACE "/W[1-4]" " /W0 " ${flag_var} "${${flag_var}}")
-    # NOTE(zhouwei25): GPU compile have too high memory utilization when parallel compiling
-    if(NOT WITH_GPU)
+
+    # NOTE(zhouwei25): GPU compilation has too high memory utilization when compiling in parallel.
+    # For Visual Studio generators, /MP should be added;
+    # for other generators such as Ninja, there is no need to add /MP.
+    if(CMAKE_GENERATOR MATCHES "Visual Studio" AND NOT WITH_GPU)
+      math(EXPR PROCESS_MAX "${CPU_CORES} * 2 / 3")
       set(${flag_var} "${${flag_var}} /MP${PROCESS_MAX}")
     endif()
   endforeach(flag_var)
@@ -305,6 +314,17 @@ else()
   endif()
 endif()

+if(WITH_DISTRIBUTE)
+  if(LINUX)
+    set(WITH_GLOO ON CACHE STRING "Enable GLOO when compiling WITH_DISTRIBUTE=ON." FORCE)
+  endif()
+  if(WITH_ASCEND_CL)
+    # disable WITH_PSCORE for NPU before including third_party
+    MESSAGE(WARNING "Disable WITH_PSCORE when compiling with NPU. Force WITH_PSCORE=OFF.")
+    set(WITH_PSCORE OFF CACHE BOOL "Disable WITH_PSCORE when compiling with NPU" FORCE)
+  endif()
+endif()
+
 include(third_party)  # download, build, install third_party, Contains about 20+ dependencies

 include(flags)  # set paddle compile flags
@@ -315,12 +335,6 @@ if(WITH_PROFILER)
   add_definitions(-DWITH_GPERFTOOLS)
 endif()

-if(WITH_DISTRIBUTE)
-  if(LINUX)
-    set(WITH_GLOO ON CACHE STRING "Enable GLOO when compiling WITH_DISTRIBUTE=ON." FORCE)
-  endif()
-endif()
-
 include(ccache)   # set ccache for compilation
 include(util)     # set unittest and link libs
 include(version)  # set PADDLE_VERSION
@@ -336,8 +350,9 @@ endif()
 if(WITH_ARM)
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
-  set(WITH_XBYAK OFF CACHE STRING "Disable XBYAK when compiling WITH_ARM=ON" FORCE)
+  set(WITH_XBYAK OFF CACHE STRING "Disable XBYAK when compiling WITH_ARM=ON." FORCE)
   set(WITH_MKL OFF CACHE STRING "Disable MKL when compiling WITH_ARM=ON." FORCE)
+  set(WITH_AVX OFF CACHE STRING "Disable AVX when compiling WITH_ARM=ON." FORCE)
   add_definitions(-DPADDLE_WITH_ARM)
 endif()

paddle/fluid/operators/determinant_op.cc

Lines changed: 78 additions & 0 deletions
@@ -84,6 +84,73 @@ class DeterminantGradOpMaker : public framework::SingleGradOpMaker<T> {
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(DeterminantGradNoNeedBufferVarsInferer,
                                     "Input");

+class SlogDeterminantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SlogDeterminant");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SlogDeterminant");
+  }
+};
+
+class SlogDeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput(
+        "Input",
+        "(Tensor) The input tensor, from which the slog determinant is taken.");
+    AddOutput("Out",
+              "(Tensor) The output tensor containing the sign and the log of "
+              "the absolute value of the determinant of the input.");
+
+    AddComment(R"DOC(
+SlogDeterminant Operator.)DOC");
+  }
+};
+
+class SlogDeterminantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
+                   "SlogDeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
+                   framework::GradVarName("Input"), "SlogDeterminantGradOp");
+
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class SlogDeterminantGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("slogdeterminant_grad");
+    grad_op->SetInput("Input", this->Input("Input"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("Input"),
+                       this->InputGrad("Input"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer,
+                                    "Input");
+
 }  // namespace operators
 }  // namespace paddle

@@ -97,3 +164,14 @@ REGISTER_OP_CPU_KERNEL(determinant, ops::DeterminantKernel<int>,
                        ops::DeterminantKernel<float>,
                        ops::DeterminantKernel<double>,
                        ops::DeterminantKernel<bool>);
+
+REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
+                  ops::SlogDeterminantOpMaker,
+                  ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,
+                  ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OP_CPU_KERNEL(slogdeterminant, ops::SlogDeterminantKernel<int>,
+                       ops::SlogDeterminantKernel<int64_t>,
+                       ops::SlogDeterminantKernel<float>,
+                       ops::SlogDeterminantKernel<double>,
+                       ops::SlogDeterminantKernel<bool>);
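
Editor's note: the CPU kernels registered above (ops::SlogDeterminantKernel<...>) are expected to be declared in determinant_op.h, which is not shown in this diff. The following is only a rough, hypothetical sketch of what such a kernel might look like, assuming the Tensor/ExecutionContext API already used in this commit, Eigen being available, and the output layout used by the CUDA kernel below (signs in the first half of Out, log of the absolute determinant in the second half); it is not the committed implementation.

// Hypothetical sketch only; the real SlogDeterminantKernel lives in
// paddle/fluid/operators/determinant_op.h and may differ.
#include <cmath>

#include <Eigen/Dense>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class SlogDeterminantKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<framework::Tensor>("Input");
    auto* output = context.Output<framework::Tensor>("Out");
    const T* in = input->data<T>();
    T* out = output->mutable_data<T>(context.GetPlace());

    auto in_dims = input->dims();
    int rank = static_cast<int>(in_dims[in_dims.size() - 1]);  // rank x rank matrices
    int64_t batch = input->numel() / (rank * rank);            // number of matrices

    for (int64_t b = 0; b < batch; ++b) {
      Eigen::MatrixXf m(rank, rank);
      for (int i = 0; i < rank; ++i) {
        for (int j = 0; j < rank; ++j) {
          m(i, j) = static_cast<float>(in[b * rank * rank + i * rank + j]);
        }
      }
      float det = m.determinant();
      // First half of Out: sign of the determinant; second half: log(|det|).
      out[b] = static_cast<T>((det > 0.f) - (det < 0.f));
      out[b + batch] = static_cast<T>(std::log(std::abs(det)));
    }
  }
};

}  // namespace operators
}  // namespace paddle
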
paddle/fluid/operators/determinant_op.cu

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/determinant_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void Determinant(const size_t numel, const T* in, int rank, T* out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < numel) {
+    // Each thread handles one rank x rank matrix of the batch.
+    Eigen::MatrixXf matrix(rank, rank);
+    int offset = tid * rank * rank;
+    for (int i = 0; i < rank; ++i) {
+      for (int j = 0; j < rank; ++j) {
+        matrix(i, j) = in[offset + rank * i + j];
+      }
+    }
+    out[tid] = matrix.determinant();
+  }
+}
+
+template <typename T>
+__global__ void DeterminantGrad(const size_t numel, T* out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < numel) {
+    out[tid] = static_cast<T>(1);
+  }
+}
+
+template <typename T>
+class DeterminantCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    const auto* input_data = input->data<T>();
+    auto input_dim = input->dims().Get();
+    auto input_dim_size = input->dims().size();
+
+    std::vector<int64_t> res_in = vectorize(framework::stride(input->dims()));
+    paddle::framework::Tensor input_stride_tensor;
+    framework::TensorFromVector<int64_t>(res_in, context.device_context(),
+                                         &input_stride_tensor);
+
+    auto* output = context.Output<framework::Tensor>("Out");
+    auto* output_data = output->mutable_data<T>(context.GetPlace());
+    auto output_dim = output->dims().Get();
+    auto output_dim_size = output->dims().size();
+    auto numel = output->numel();
+
+    int threads = PADDLE_CUDA_NUM_THREADS;
+    int blocks = (numel + threads - 1) / threads;
+
+    auto rank = input_dim[input_dim_size - 1];
+    Determinant<T><<<blocks, threads>>>(numel, input_data, rank, output_data);
+  }
+};
+
+template <typename T>
+class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const auto* dout =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    const T* dout_data = dout->data<T>();
+    auto dout_dim = vectorize(dout->dims());
+
+    auto* dx =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+    T* dx_data = dx->mutable_data<T>(context.GetPlace());
+
+    // dx lives in device memory, so fill it with ones from a kernel instead
+    // of a host-side loop.
+    int64_t numel = dx->numel();
+    int threads = PADDLE_CUDA_NUM_THREADS;
+    int blocks = (numel + threads - 1) / threads;
+    DeterminantGrad<T><<<blocks, threads>>>(numel, dx_data);
+  }
+};
+
+template <typename T>
+__global__ void SlogDeterminant(const size_t total, const T* in, int rank,
+                                T* out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < total) {
+    // Each thread handles one rank x rank matrix of the batch.
+    Eigen::MatrixXf matrix(rank, rank);
+    int offset = tid * rank * rank;
+    for (int i = 0; i < rank; ++i) {
+      for (int j = 0; j < rank; ++j) {
+        matrix(i, j) = in[offset + rank * i + j];
+      }
+    }
+    // The first half of the output holds the sign of the determinant, the
+    // second half holds the log of its absolute value.
+    float det = matrix.determinant();
+    out[tid] = static_cast<T>((det > 0.f) - (det < 0.f));
+    out[tid + total] = static_cast<T>(log(fabs(det)));
+  }
+}
+
+template <typename T>
+class SlogDeterminantCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    const auto* input_data = input->data<T>();
+    auto input_dim = input->dims().Get();
+    auto input_dim_size = input->dims().size();
+
+    std::vector<int64_t> res_in = vectorize(framework::stride(input->dims()));
+    paddle::framework::Tensor input_stride_tensor;
+    framework::TensorFromVector<int64_t>(res_in, context.device_context(),
+                                         &input_stride_tensor);
+
+    auto* output = context.Output<framework::Tensor>("Out");
+    auto* output_data = output->mutable_data<T>(context.GetPlace());
+    auto output_dim = output->dims().Get();
+    auto output_dim_size = output->dims().size();
+
+    int threads = PADDLE_CUDA_NUM_THREADS;
+    // Out stores a sign and a log|det| per matrix, so the batch size is half
+    // of Out's element count.
+    auto numel = output->numel() / 2;
+    int blocks = (numel + threads - 1) / threads;
+
+    auto rank = input_dim[input_dim_size - 1];
+    SlogDeterminant<T><<<blocks, threads>>>(numel, input_data, rank,
+                                            output_data);
+  }
+};
+
+template <typename T>
+class SlogDeterminantGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    // The gradient is written to Input@GRAD, which has the same shape as the
+    // forward input.
+    auto* dx =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+    T* dx_data = dx->mutable_data<T>(context.GetPlace());
+
+    int64_t numel = dx->numel();
+    int threads = PADDLE_CUDA_NUM_THREADS;
+    int blocks = (numel + threads - 1) / threads;
+
+    // Placeholder gradient: fill Input@GRAD with ones on the device.
+    DeterminantGrad<T><<<blocks, threads>>>(numel, dx_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(determinant, ops::DeterminantCUDAKernel<int>,
+                        ops::DeterminantCUDAKernel<int64_t>,
+                        ops::DeterminantCUDAKernel<float>,
+                        ops::DeterminantCUDAKernel<double>,
+                        ops::DeterminantCUDAKernel<bool>);
+
+REGISTER_OP_CUDA_KERNEL(determinant_grad, ops::DeterminantGradCUDAKernel<int>,
+                        ops::DeterminantGradCUDAKernel<int64_t>,
+                        ops::DeterminantGradCUDAKernel<float>,
+                        ops::DeterminantGradCUDAKernel<double>);
+
+REGISTER_OP_CUDA_KERNEL(slogdeterminant, ops::SlogDeterminantCUDAKernel<int>,
+                        ops::SlogDeterminantCUDAKernel<int64_t>,
+                        ops::SlogDeterminantCUDAKernel<float>,
+                        ops::SlogDeterminantCUDAKernel<double>,
+                        ops::SlogDeterminantCUDAKernel<bool>);
+
+REGISTER_OP_CUDA_KERNEL(slogdeterminant_grad,
+                        ops::SlogDeterminantGradCUDAKernel<int>,
+                        ops::SlogDeterminantGradCUDAKernel<int64_t>,
+                        ops::SlogDeterminantGradCUDAKernel<float>,
+                        ops::SlogDeterminantGradCUDAKernel<double>);
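
Editor's note: as a quick illustration of the output layout the slogdeterminant kernels above assume (not part of the commit), for a batch of n square matrices Out holds 2*n elements, the first n being the signs and the last n the logs of the absolute determinants, so det = sign * exp(logabsdet). A minimal, self-contained host-side sketch with made-up values:

// Host-side illustration of the assumed Out layout; values are hypothetical.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical kernel output for a batch of two matrices.
  std::vector<float> out = {1.f, -1.f,                          // signs
                            std::log(6.f), std::log(2.f)};      // log|det|
  const size_t n = out.size() / 2;
  for (size_t b = 0; b < n; ++b) {
    // Recover the determinant from its sign and log-absolute value.
    float det = out[b] * std::exp(out[b + n]);
    std::printf("matrix %zu: sign=%g log|det|=%g det=%g\n", b, out[b],
                out[b + n], det);
  }
  return 0;
}
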
