Skip to content

Commit

Permalink
Add Assert Op (PaddlePaddle#24280)
Browse files Browse the repository at this point in the history
1. To make ProgramTranslator support the `assert` grammar, this PR adds an `assert` Python API and the corresponding C++ code. 

2. Fix a bug: graph_pattern_detector.h includes <gtest/gtest_prod.h> but didn't declare the dependency in its CMakeLists, which could cause a standalone (single-target) build failure.

3. Refactor the `Formatter` in print_op to make it reusable, and reuse that formatter for printing in the assert op.
  • Loading branch information
zhhsplendid authored May 8, 2020
1 parent 8c296de commit 8a1a2af
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 151 deletions.
8 changes: 7 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)

# graph_pattern_detector.h includes <gtest/gtest_prod.h> (for FRIEND_TEST),
# so gtest must be an explicit dependency — but only when testing is enabled,
# to avoid pulling gtest into non-test builds.
SET(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits)
if (WITH_TESTING)
SET(GRAPH_PATTERN_DETECTOR_DEPS ${GRAPH_PATTERN_DETECTOR_DEPS} gtest)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})

cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ if (WITH_GPU)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)

# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
Expand Down Expand Up @@ -119,6 +120,7 @@ else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif()

cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
Expand Down
108 changes: 108 additions & 0 deletions paddle/fluid/operators/assert_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/tensor_formatter.h"

// Input / attribute names shared by the assert op classes below.
// NOTE: these live in the global namespace, not paddle::operators.
const char kCond[] = "Cond";            // boolean scalar condition input
const char kData[] = "Data";            // tensors printed on assertion failure
const char kSummarize[] = "summarize";  // max entries printed per tensor

namespace paddle {
namespace operators {

using framework::LoDTensor;

// Runtime assert operator: checks that the boolean scalar input `Cond` is
// true. If it is false, prints every tensor wired into `Data` (to aid
// debugging) and then throws, aborting execution.
class AssertOp : public framework::OperatorBase {
 public:
  AssertOp(const std::string &type, const framework::VariableNameMap &inputs,
           const framework::VariableNameMap &outputs,
           const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    const framework::Variable *cond_var_ptr = scope.FindVar(Input(kCond));
    PADDLE_ENFORCE_NOT_NULL(cond_var_ptr,
                            platform::errors::NotFound(
                                "Input(Condition) of AssertOp is not found."));
    const LoDTensor &cond = cond_var_ptr->Get<LoDTensor>();
    // The condition must be a 1-element tensor so it can be read as a single
    // boolean value.
    PADDLE_ENFORCE_EQ(
        cond.dims(), paddle::framework::make_ddim({1}),
        platform::errors::InvalidArgument(
            "The numel of Input(Condition) of AssertOp must be 1. But now "
            "the Condition's shape is %s.",
            cond.dims().to_str()));

    bool cond_data = GetCondData(cond);
    if (cond_data) {
      return;  // Assertion holds; nothing to do.
    }

    // Assertion failed: print each Data tensor before throwing, truncating to
    // `summarize` entries per tensor (-1 means print everything).
    TensorFormatter formatter;
    formatter.SetSummarize(Attr<int64_t>(kSummarize));

    const std::vector<std::string> &x_names = Inputs(kData);
    for (const std::string &name : x_names) {
      const framework::Variable *x_var_ptr = scope.FindVar(name);
      // Fail with a clear NotFound error instead of dereferencing a null
      // pointer (mirrors the check done for Cond above).
      PADDLE_ENFORCE_NOT_NULL(
          x_var_ptr,
          platform::errors::NotFound(
              "Input(Data) variable '%s' of AssertOp is not found.", name));
      const framework::LoDTensor &x_tensor = x_var_ptr->Get<LoDTensor>();
      formatter.Print(x_tensor, name);
    }

    PADDLE_THROW(platform::errors::InvalidArgument(
        "The condition variable '%s' of AssertOp must be "
        "true, but received false",
        Input(kCond)));
  }
};

// Declares the op's inputs and attributes for the op proto registry:
// Cond (scalar bool), Data (duplicable tensor list), summarize (int64 attr).
class AssertOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        kCond,
        "The boolean scalar condition tensor which is asserted to be true.");
    AddInput(kData,
             "The tensors to print when the assert condition is not true.")
        .AsDuplicable();
    // Typo fix in the user-facing doc string: "less then" -> "less than".
    AddAttr<int64_t>(
        kSummarize,
        "The number of entries of each tensor to print when the "
        "assert condition is not true. -1 means print all entries. If "
        "the number of entries of a tensor is less than "
        "summarize_num, this OP will print all entries of the tensor.")
        .SetDefault(-1);
    AddComment(
        R"DOC(Assert the input Condition Tensor is true and print Tensors if the Condition Tensor is false.)DOC");
  }
};

// Shape inference for assert: only verifies that the Cond input exists.
// The op produces no outputs, so there are no output shapes to set.
class AssertOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    // NOTE(review): the error label is "Condition" while the actual input
    // name (kCond) is "Cond" — intentional for readability, presumably.
    OP_INOUT_CHECK(context->HasInputs(kCond), "Input", "Condition", "AssertOp");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
// assert has no gradient, so empty grad-op makers are registered for both
// the static-graph (OpDesc) and imperative (OpBase) modes.
REGISTER_OPERATOR(
    assert, ops::AssertOp, ops::AssertOpProtoMaker, ops::AssertOpInferShape,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
158 changes: 10 additions & 148 deletions paddle/fluid/operators/print_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/assign_op.h"
#include "paddle/fluid/operators/tensor_formatter.h"

namespace paddle {
namespace operators {
Expand All @@ -28,133 +29,6 @@ const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";

// RAII guard that serializes log emission: the constructor acquires a
// process-wide mutex and the destructor releases it, so a whole log record
// can be written without interleaving from other threads.
class LogGuard {
 public:
  inline LogGuard() { GetMutex().lock(); }

  inline ~LogGuard() { GetMutex().unlock(); }

 private:
  // Function-local static ("Meyers singleton") mutex shared by every
  // LogGuard instance in the process.
  static std::mutex &GetMutex() {
    static std::mutex instance;
    return instance;
  }
};

// Accumulates a tensor's metadata (name, message, LoD, place, dims, layout,
// dtype) and a pointer to its raw data, formats everything into an internal
// stringstream, and finally writes the whole record to CLOG under a LogGuard
// so the output is not interleaved with other threads' logs.
// (Note: "Formater" spelling is historical.)
struct Formater {
  std::string message;   // optional free-form message; skipped when empty
  std::string name;      // variable name; skipped when empty
  std::string dims;      // pre-rendered shape string; skipped when empty
  // typeid(const char) acts as the "dtype not set" sentinel — PrintDtype
  // skips printing when the dtype still holds this default.
  std::type_index dtype{typeid(const char)};
  std::string layout;    // pre-rendered layout string; skipped when empty
  framework::LoD lod;    // level-of-detail info; skipped when empty
  int summarize;         // max entries to print; -1 means print all
  void *data{nullptr};   // raw tensor data, interpreted according to dtype
  platform::Place place; // device the tensor lives on
  std::stringstream logs;  // buffer the full record before a single CLOG write

  // Renders all configured sections, then emits the buffered text in one
  // locked CLOG write. `size` is the number of data entries available.
  void operator()(size_t size) {
    PrintName();
    PrintMessage();
    PrintLod();
    PrintPlace();
    PrintDims();
    PrintLayout();
    PrintDtype();
    PrintData(size);
    LogGuard guard;
    CLOG << logs.str();
  }

 private:
  void PrintPlace() { logs << "  - place: " << place << std::endl; }
  void PrintMessage() {
    if (!message.empty()) {
      logs << "  - message: " << message << std::endl;
    }
  }
  void PrintName() {
    if (!name.empty()) {
      logs << "Variable: " << name << std::endl;
    }
  }
  void PrintDims() {
    if (!dims.empty()) {
      logs << "  - shape: " << dims << std::endl;
    }
  }
  void PrintDtype() {
    // Only print when dtype was explicitly set (see sentinel above).
    if (!framework::IsType<const char>(dtype)) {
      logs << "  - dtype: " << platform::demangle(dtype.name()) << std::endl;
    }
  }
  void PrintLayout() {
    if (!layout.empty()) {
      logs << "  - layout: " << layout << std::endl;
    }
  }
  // Renders the LoD as nested brace lists, e.g. {{0, 2, 5}}.
  void PrintLod() {
    if (!lod.empty()) {
      logs << "  - lod: {";
      for (auto level : lod) {
        logs << "{";
        bool is_first = true;
        for (auto i : level) {
          if (is_first) {
            logs << i;
            is_first = false;
          } else {
            logs << ", " << i;
          }
        }
        logs << "}";
      }
      logs << "}" << std::endl;
    }
  }

  // Dispatches on the stored dtype to print `size` entries of `data`;
  // unsupported types print a placeholder line instead of failing.
  void PrintData(size_t size) {
    PADDLE_ENFORCE_NOT_NULL(data);
    // print float
    if (framework::IsType<const float>(dtype)) {
      Display<float>(size);
    } else if (framework::IsType<const double>(dtype)) {
      Display<double>(size);
    } else if (framework::IsType<const int>(dtype)) {
      Display<int>(size);
    } else if (framework::IsType<const int64_t>(dtype)) {
      Display<int64_t>(size);
    } else if (framework::IsType<const bool>(dtype)) {
      Display<bool>(size);
    } else {
      logs << "  - data: unprintable type: " << dtype.name() << std::endl;
    }
  }

  // Prints up to `summarize` entries (all entries when summarize == -1) as a
  // space-separated bracketed list. Note: caps the member `summarize` to
  // `size` in place, so a second call sees the clamped value.
  template <typename T>
  void Display(size_t size) {
    auto *d = reinterpret_cast<T *>(data);
    logs << "  - data: [";
    if (summarize != -1) {
      summarize = std::min(size, (size_t)summarize);
      if (summarize > 0) {
        logs << d[0];
        for (int i = 1; i < summarize; ++i) {
          logs << " " << d[i];
        }
      }
    } else {
      if (size > 0) {
        logs << d[0];
        for (size_t i = 1; i < size; ++i) {
          logs << " " << d[i];
        }
      }
    }
    logs << "]" << std::endl;
  }
};

// TODO(ChunweiYan) there should be some other printers for TensorArray
class PrintOp : public framework::OperatorBase {
public:
Expand Down Expand Up @@ -211,27 +85,15 @@ class PrintOp : public framework::OperatorBase {
TensorCopy(in_tensor, place, &printed_tensor);
}

Formater formater;
formater.place = place;
formater.message = Attr<std::string>("message");
if (Attr<bool>("print_tensor_name")) {
formater.name = printed_var_name;
}
if (Attr<bool>("print_tensor_type")) {
formater.dtype = framework::ToTypeIndex(printed_tensor.type());
}
if (Attr<bool>("print_tensor_shape")) {
formater.dims = printed_tensor.dims().to_str();
}
if (Attr<bool>("print_tensor_lod")) {
formater.lod = printed_tensor.lod();
}
if (Attr<bool>("print_tensor_layout")) {
formater.layout = framework::DataLayoutToString(printed_tensor.layout());
}
formater.summarize = Attr<int>("summarize");
formater.data = reinterpret_cast<void *>(printed_tensor.data<void>());
formater(printed_tensor.numel());
TensorFormatter formatter;
const std::string &name =
Attr<bool>("print_tensor_name") ? printed_var_name : "";
formatter.SetPrintTensorType(Attr<bool>("print_tensor_type"));
formatter.SetPrintTensorShape(Attr<bool>("print_tensor_shape"));
formatter.SetPrintTensorLod(Attr<bool>("print_tensor_lod"));
formatter.SetPrintTensorLayout(Attr<bool>("print_tensor_layout"));
formatter.SetSummarize(static_cast<int64_t>(Attr<int>("summarize")));
formatter.Print(printed_tensor, name, Attr<std::string>("message"));
}

private:
Expand Down
Loading

0 comments on commit 8a1a2af

Please sign in to comment.