While op forward for sentimental analysis #6140
Changes from all commits: 25b0c37, 9d48911, e0a2300, dab22a9, 52997cc, 4fa4ec9, dc08ac5, bf38d85, a0a695b
@@ -466,7 +466,12 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
   try {
-    return framework::make_ddim(var->Shape());
+    auto shape = var->Shape();
+    if (shape.empty()) {
+      return framework::make_ddim({0UL});
+    } else {
+      return framework::make_ddim(var->Shape());
+    }
   } catch (...) {
     VLOG(5) << "GetDim of variable " << name << " error";
     std::rethrow_exception(std::current_exception());

Review comment: In which case is the shape empty?
Reply: The final step of an RNN network; the memory's gradient could be empty.
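As a standalone illustration of the fallback above (a sketch, not the framework code): a toy GetDim with the variable's stored shape replaced by a plain vector, mapping an empty shape, such as an RNN memory's empty gradient, to the 1-D placeholder {0}.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for framework::DDim: a list of dimension sizes.
using DDim = std::vector<int64_t>;

// Mirrors the patched branch: an empty shape becomes the placeholder {0}
// so compile-time shape inference always receives a well-formed DDim.
DDim MakeDDimOrPlaceholder(const std::vector<int64_t> &shape) {
  if (shape.empty()) {
    return DDim{0};  // like make_ddim({0UL}) in the patch
  }
  return DDim(shape.begin(), shape.end());
}

int main() {
  std::cout << MakeDDimOrPlaceholder({}).size() << "\n";         // 1: placeholder
  std::cout << MakeDDimOrPlaceholder({32, 128}).size() << "\n";  // 2
  return 0;
}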
@@ -27,7 +27,7 @@ class WriteToArrayOp : public ArrayOp {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     auto *x = scope.FindVar(Input("X"));
-    PADDLE_ENFORCE(x != nullptr, "X must be set");
+    if (x == nullptr) return;
     auto &x_tensor = x->Get<framework::LoDTensor>();
     size_t offset = GetOffset(scope, dev_ctx);
     auto *out =
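The change above turns a hard failure into a silent no-op when the input variable is absent. A minimal sketch of the same pattern, with a hypothetical Scope reduced to a string-to-Variable map:

#include <iostream>
#include <string>
#include <unordered_map>

struct Variable {};  // holds a LoDTensor in the real framework

// Toy stand-in for framework::Scope: a lookup that may fail.
using Scope = std::unordered_map<std::string, Variable>;

// Mirrors the patched Run: a missing input (e.g. an empty gradient fed by
// the backward pass) makes the write a no-op instead of a fatal error.
void WriteToArray(const Scope &scope, const std::string &input_name) {
  auto it = scope.find(input_name);
  if (it == scope.end()) return;  // was: PADDLE_ENFORCE(x != nullptr, ...)
  std::cout << "wrote " << input_name << " into the array\n";
}

int main() {
  Scope scope{{"X", Variable{}}};
  WriteToArray(scope, "X");       // writes
  WriteToArray(scope, "X@GRAD");  // silently skipped
  return 0;
}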
@@ -76,7 +76,9 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
     PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
     PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
                       "The number of element of subscript index must be 1");
-    PADDLE_ENFORCE(context->HasInput("X"), NotHasXError());
+    if (!context->HasInput("X")) {
+      return;
+    }
     PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
     context->SetOutputDim("Out", context->GetInputDim("X"));
   }
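The same tolerance applies at shape-inference time: Out's dim is set from X only when X exists. A toy sketch (hypothetical names, with std::optional standing in for the InferShapeContext query):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using DDim = std::vector<int64_t>;

// Toy InferShape for WriteToArray: propagate X's dim to Out only when X
// is present; a missing X simply leaves Out's shape unset.
std::optional<DDim> InferWriteToArrayShape(const std::optional<DDim> &x_dim) {
  if (!x_dim.has_value()) return std::nullopt;  // was a hard PADDLE_ENFORCE
  return x_dim;  // SetOutputDim("Out", GetInputDim("X")) in the real op
}

int main() {
  std::cout << InferWriteToArrayShape(DDim{32, 128})->size() << "\n";     // 2
  std::cout << InferWriteToArrayShape(std::nullopt).has_value() << "\n";  // 0
  return 0;
}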
@@ -99,9 +101,10 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
     auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
                             "Cannot found %s", out_name);
     out.SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
-    auto &x =
-        detail::Ref(block->FindVarRecursive(x_name), "Cannot found %s", x_name);
-    out.SetDataType(x.GetDataType());
+    auto *x = block->FindVarRecursive(x_name);
+    if (x != nullptr) {
+      out.SetDataType(x->GetDataType());
+    }
   }
 };
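detail::Ref dereferences a pointer or throws when it is null; the patch replaces it with an explicit null check so a missing x_name no longer aborts var-type inference. A minimal sketch of the two behaviors, with a toy Ref helper modeled on that idea:

#include <iostream>
#include <stdexcept>
#include <string>

// Toy version of detail::Ref: dereference a pointer or throw.
template <typename T>
T &Ref(T *ptr, const std::string &msg) {
  if (ptr == nullptr) throw std::runtime_error(msg);
  return *ptr;
}

struct VarDesc {
  std::string data_type = "fp32";
};

int main() {
  VarDesc w;
  VarDesc *found = &w;
  VarDesc *missing = nullptr;

  // Old behavior: a missing variable aborts var-type inference.
  try {
    std::cout << Ref(missing, "Cannot found X").data_type << "\n";
  } catch (const std::runtime_error &e) {
    std::cout << "old: threw \"" << e.what() << "\"\n";
  }

  // New behavior: the optional step is skipped when the variable is absent.
  if (missing != nullptr) std::cout << missing->data_type << "\n";  // skipped
  if (found != nullptr) std::cout << "new: " << found->data_type << "\n";
  return 0;
}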
@@ -121,10 +124,13 @@ class ReadFromArrayOp : public ArrayOp {
     PADDLE_ENFORCE(out != nullptr, "Out must be set");
     auto *out_tensor = out->GetMutable<framework::LoDTensor>();
     size_t offset = GetOffset(scope, dev_ctx);
-    PADDLE_ENFORCE_LT(offset, x_array.size());
-    framework::CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx,
-                        out_tensor);
-    out_tensor->set_lod(x_array[offset].lod());
+    if (offset < x_array.size()) {
+      framework::CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx,
+                          out_tensor);
+      out_tensor->set_lod(x_array[offset].lod());
+    } else {
+      VLOG(10) << "offset " << offset << " >= " << x_array.size();
+    }
   }
 };

Review comment: In which case will offset exceed x_array.size()?
Reply: The final timestep, when reading the memory.
Review comment: In fact, the backward of a control operator does not calculate gradients, so no_grad_vars makes no sense for a control operator.
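A loose sketch of that point, under my assumption (not stated in the PR) that backward generation drops gradient outputs listed in no_grad_vars: an ordinary op has gradient outputs to filter, while a control op such as While produces no gradients of its own, so the set has nothing to act on.

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy op: only its gradient outputs matter here; empty for control ops,
// whose "backward" just re-runs a sub-block.
struct Op {
  std::vector<std::string> grad_outputs;
};

// Keep only the gradient outputs not suppressed by no_grad_vars.
std::vector<std::string> BackwardOutputs(const Op &op,
                                         const std::set<std::string> &no_grad) {
  std::vector<std::string> kept;
  for (const auto &g : op.grad_outputs) {
    if (no_grad.count(g) == 0) kept.push_back(g);
  }
  return kept;
}

int main() {
  std::set<std::string> no_grad{"W@GRAD"};
  Op fc{{"W@GRAD", "X@GRAD"}};
  Op while_op{{}};  // control op: nothing for the filter to act on
  std::cout << BackwardOutputs(fc, no_grad).size() << "\n";        // 1
  std::cout << BackwardOutputs(while_op, no_grad).size() << "\n";  // 0
  return 0;
}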