Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions paddle/fluid/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
return os;
}

DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) {
PADDLE_ENFORCE_GE(src.size(), 3,
platform::errors::InvalidArgument(
"The rank of src dim should be at least 3 "
"in flatten_to_3d, but received %d.",
src.size()));
PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true,
platform::errors::InvalidArgument(
"The num_row_dims should be inside [1, %d] "
"in flatten_to_3d, but received %d.",
src.size() - 1, num_row_dims));
PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true,
platform::errors::InvalidArgument(
"The num_col_dims should be inside [2, %d] "
"in flatten_to_3d, but received %d.",
src.size(), num_col_dims));
PADDLE_ENFORCE_GE(
num_col_dims, num_row_dims,
platform::errors::InvalidArgument(
"The num_row_dims should be less than num_col_dims in flatten_to_3d,"
"but received num_row_dims = %d, num_col_dims = %d.",
num_row_dims, num_col_dims));

return DDim({product(slice_ddim(src, 0, num_row_dims)),
product(slice_ddim(src, num_row_dims, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
}

DDim flatten_to_2d(const DDim& src, int num_col_dims) {
return DDim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ int arity(const DDim& ddim);

std::ostream& operator<<(std::ostream&, const DDim&);

/**
* \brief Flatten dim to 3d
* e.g., DDim d = mak_ddim({1, 2, 3, 4, 5, 6})
* flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30}
*/
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims);

// Reshape a tensor to a matrix. The matrix's first dimension(column length)
// will be the product of tensor's first `num_col_dims` dimensions.
DDim flatten_to_2d(const DDim& src, int num_col_dims);
Expand Down
89 changes: 89 additions & 0 deletions paddle/fluid/operators/eigvals_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigvals_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A complex- or real-valued tensor with shape (*, n, n)"
"where * is zero or more batch dimensions");
AddOutput("Out",
"(Tensor) The output tensor with shape (*,n) cointaining the "
"eigenvalues of X.");
AddComment(R"DOC(eigvals operator
Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real.
)DOC");
}
};

class EigvalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals");

DDim x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s].",
x_dims.size(), x_dims));

if (ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)) {
int last_dim = x_dims.size() - 1;
PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim - 1],
platform::errors::InvalidArgument(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s].",
x_dims));
}

auto output_dims = vectorize(x_dims);
output_dims.resize(x_dims.size() - 1);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
}
};

class EigvalsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const {
auto input_dtype = ctx->GetInputDataType("X");
auto output_dtype = framework::IsComplexType(input_dtype)
? input_dtype
: framework::ToComplexType(input_dtype);
ctx->SetOutputDataType("Out", output_dtype);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker,
ops::EigvalsOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(eigvals,
ops::EigvalsKernel<plat::CPUDeviceContext, float>,
ops::EigvalsKernel<plat::CPUDeviceContext, double>,
ops::EigvalsKernel<plat::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EigvalsKernel<plat::CPUDeviceContext,
paddle::platform::complex<double>>);
129 changes: 129 additions & 0 deletions paddle/fluid/operators/eigvals_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <complex>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename T>
struct PaddleComplex {
using Type = paddle::platform::complex<T>;
};
template <>
struct PaddleComplex<paddle::platform::complex<float>> {
using Type = paddle::platform::complex<float>;
};
template <>
struct PaddleComplex<paddle::platform::complex<double>> {
using Type = paddle::platform::complex<double>;
};

template <typename T>
struct StdComplex {
using Type = std::complex<T>;
};
template <>
struct StdComplex<paddle::platform::complex<float>> {
using Type = std::complex<float>;
};
template <>
struct StdComplex<paddle::platform::complex<double>> {
using Type = std::complex<double>;
};

template <typename T>
using PaddleCType = typename PaddleComplex<T>::Type;
template <typename T>
using StdCType = typename StdComplex<T>::Type;
template <typename T>
using EigenMatrixPaddle = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorPaddle = Eigen::Matrix<PaddleCType<T>, Eigen::Dynamic, 1>;
template <typename T>
using EigenMatrixStd =
Eigen::Matrix<StdCType<T>, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorStd = Eigen::Matrix<StdCType<T>, Eigen::Dynamic, 1>;

static void SpiltBatchSquareMatrix(const Tensor &input,
std::vector<Tensor> *output) {
DDim input_dims = input.dims();
int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim];

DDim flattened_input_dims, flattened_output_dims;
if (input_dims.size() > 2) {
flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim);
} else {
flattened_input_dims = framework::make_ddim({1, n_dim, n_dim});
}

Tensor flattened_input;
flattened_input.ShareDataWith(input);
flattened_input.Resize(flattened_input_dims);
(*output) = flattened_input.Split(1, 0);
}

template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");

auto input_type = input->type();
auto output_type = framework::IsComplexType(input_type)
? input_type
: framework::ToComplexType(input_type);
output->mutable_data(ctx.GetPlace(), output_type);

std::vector<Tensor> input_matrices;
SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);

int n_dim = input_matrices[0].dims()[1];
int n_batch = input_matrices.size();

DDim output_dims = output->dims();
output->Resize(framework::make_ddim({n_batch, n_dim}));
std::vector<Tensor> output_vectors = output->Split(1, 0);

Eigen::Map<EigenMatrixPaddle<T>> input_emp(NULL, n_dim, n_dim);
Eigen::Map<EigenVectorPaddle<T>> output_evp(NULL, n_dim);
EigenMatrixStd<T> input_ems;
EigenVectorStd<T> output_evs;

for (int i = 0; i < n_batch; ++i) {
new (&input_emp) Eigen::Map<EigenMatrixPaddle<T>>(
input_matrices[i].data<T>(), n_dim, n_dim);
new (&output_evp) Eigen::Map<EigenVectorPaddle<T>>(
output_vectors[i].data<PaddleCType<T>>(), n_dim);
input_ems = input_emp.template cast<StdCType<T>>();
output_evs = input_ems.eigenvalues();
output_evp = output_evs.template cast<PaddleCType<T>>();
}
output->Resize(output_dims);
}
};
} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1030,3 +1030,4 @@ if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120)
endif()
set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120)
set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400)
6 changes: 6 additions & 0 deletions python/paddle/fluid/tests/unittests/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,12 @@ def check_output_customized(self, checker, custom_place=None):
outs.sort(key=len)
checker(outs)

def check_output_with_place_customized(self, checker, place):
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
outs.sort(key=len)
checker(outs)

def _assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names):
Expand Down
Loading