Skip to content

Commit 294b58a

Browse files
author
ktlichkid
committed
Changed registered type
1 parent df80b6e commit 294b58a

File tree

2 files changed

+14
-81
lines changed

2 files changed

+14
-81
lines changed

paddle/fluid/operators/beam_search_op.cc

Lines changed: 13 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
197197
return stream.str();
198198
}
199199

200-
class BeamSearchOpMaker
201-
: public framework::OpProtoAndCheckerMaker {
200+
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
202201
public:
203202
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
204203
: OpProtoAndCheckerMaker(proto, op_checker) {
@@ -225,29 +224,15 @@ class BeamSearchOpMaker
225224
};
226225

227226
class BeamSearchOp : public framework::OperatorWithKernel {
228-
/*
229-
public:
230-
BeamSearchOp(const std::string& type,
231-
const framework::VariableNameMap& inputs,
232-
const framework::VariableNameMap& outputs,
233-
const framework::AttributeMap& attrs)
234-
: OperatorWithKernel(type, inputs, outputs, attrs) {}
235-
236-
BeamSearchOp(const BeamSearchOp& o)
237-
: framework::OperatorWithKernel(
238-
static_cast<const framework::OperatorBase&>(o)) {
239-
PADDLE_THROW("Not Implemented");
240-
}
241-
*/
242227
public:
243228
using framework::OperatorWithKernel::OperatorWithKernel;
244229

245230
protected:
246-
void InferShape(framework::InferShapeContext* ctx) const override {
231+
void InferShape(framework::InferShapeContext *ctx) const override {
247232
for (const std::string &arg :
248233
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
249-
PADDLE_ENFORCE(ctx->HasInput(arg),
250-
"BeamSearch need input argument '%s'", arg);
234+
PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'",
235+
arg);
251236
}
252237
for (const std::string &arg :
253238
std::vector<std::string>({"selected_ids", "selected_scores"})) {
@@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
263248
framework::OpKernelType kt = framework::OpKernelType(
264249
framework::ToDataType(
265250
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
266-
platform::CPUPlace());
251+
platform::CPUPlace());
267252
std::cout << "Get Expected type 2\n";
268-
// kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
269-
// std::cout << "Get Expected type 3\n";
270253
return kt;
271254
}
272-
/*
273-
private:
274-
void RunImpl(const framework::Scope& scope,
275-
const platform::Place& dev_place) const override {
276-
auto ids_var = scope.FindVar(Input("ids"));
277-
auto scores_var = scope.FindVar(Input("scores"));
278-
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
279-
PADDLE_ENFORCE_NOT_NULL(ids_var);
280-
PADDLE_ENFORCE_NOT_NULL(scores_var);
281-
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
282-
283-
auto& ids = ids_var->Get<framework::LoDTensor>();
284-
auto& scores = scores_var->Get<framework::LoDTensor>();
285-
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
286-
size_t level = Attr<int>("level");
287-
size_t beam_size = Attr<int>("beam_size");
288-
int end_id = Attr<int>("end_id");
289-
BeamSearch alg(ids, scores, level, beam_size, end_id);
290-
291-
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
292-
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
293-
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
294-
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
295-
auto& selected_ids_tensor =
296-
*selected_ids_var->GetMutable<framework::LoDTensor>();
297-
auto& selected_scores_tensor =
298-
*selected_scores_var->GetMutable<framework::LoDTensor>();
299-
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
300-
}
301-
*/
302255
};
303256

304257

305-
/*
306-
class BeamSearchInferShape : public framework::InferShapeBase {
307-
public:
308-
void operator()(framework::InferShapeContext *context) const override {
309-
for (const std::string &arg :
310-
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
311-
PADDLE_ENFORCE(context->HasInput(arg),
312-
"BeamSearch need input argument '%s'", arg);
313-
}
314-
for (const std::string &arg :
315-
std::vector<std::string>({"selected_ids", "selected_scores"})) {
316-
PADDLE_ENFORCE(context->HasOutput(arg),
317-
"BeamSearch need output argument '%s'", arg);
318-
}
319-
}
320-
};
321-
*/
322258
class BeamSearchInferVarType : public framework::VarTypeInference {
323259
public:
324260
void operator()(const framework::OpDesc &op_desc,
@@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
334270

335271
} // namespace operators
336272
} // namespace paddle
337-
/*
338-
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
339-
paddle::operators::BeamSearchProtoAndCheckerMaker,
340-
paddle::operators::BeamSearchInferShape,
341-
paddle::operators::BeamSearchInferVarType,
342-
paddle::framework::EmptyGradOpMaker);
343-
*/
273+
274+
344275
namespace ops = paddle::operators;
345-
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
346-
ops::BeamSearchOpMaker,
347-
ops::BeamSearchInferVarType);
276+
277+
REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker,
278+
ops::BeamSearchInferVarType);
348279
REGISTER_OP_CPU_KERNEL(
349280
beam_search,
350281
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
351-
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>);
282+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>,
283+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int>,
284+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/beam_search_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
195195
std::string ItemToString(const BeamSearch::Item& item);
196196

197197
template <typename DeviceContext, typename T>
198-
class BeamSearchOpKernel : public framework::OpKernel<T>{
198+
class BeamSearchOpKernel : public framework::OpKernel<T> {
199199
public:
200200
void Compute(const framework::ExecutionContext& context) const override {
201201
std::cout << "Compute 1\n";

0 commit comments

Comments
 (0)