Skip to content

Add get lod for debug #7375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform)
shape_inference data_transform lod_tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)

Expand Down
6 changes: 6 additions & 0 deletions paddle/framework/lod_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
return os;
}

std::string LoDToString(const LoD &lod) {
std::ostringstream stream;
stream << lod;
return stream.str();
}

LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end) {
PADDLE_ENFORCE_LT(level, in.size());
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/lod_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ using LoD = std::vector<Vector<size_t>>;
std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t);

std::string LoDToString(const LoD& lod);

LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin,
size_t elem_end);
/*
Expand Down
25 changes: 22 additions & 3 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return DDim({-1});
} else if (var->IsType<LoDTensor>()) {
}

if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
Expand All @@ -85,6 +87,21 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
}
}

static LoD GetLoD(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}});

if (var == nullptr) {
return default_lod;
}

if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().lod();
} else {
return default_lod;
}
}

std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
Expand Down Expand Up @@ -126,7 +143,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
ss << "(" << GetDims(*scope, input.second[i]) << ")";
ss << "[" << GetDims(*scope, input.second[i]) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) {
ss << ", ";
Expand All @@ -145,7 +163,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
ss << "(" << GetDims(*scope, output.second[i]) << ")";
ss << "[" << GetDims(*scope, output.second[i]) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) {
ss << ", ";
Expand Down