@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
197197 return stream.str ();
198198}
199199
200- class BeamSearchOpMaker
201- : public framework::OpProtoAndCheckerMaker {
200+ class BeamSearchOpMaker : public framework ::OpProtoAndCheckerMaker {
202201 public:
203202 BeamSearchOpMaker (OpProto *proto, OpAttrChecker *op_checker)
204203 : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -225,29 +224,15 @@ class BeamSearchOpMaker
225224};
226225
227226class BeamSearchOp : public framework ::OperatorWithKernel {
228- /*
229- public:
230- BeamSearchOp(const std::string& type,
231- const framework::VariableNameMap& inputs,
232- const framework::VariableNameMap& outputs,
233- const framework::AttributeMap& attrs)
234- : OperatorWithKernel(type, inputs, outputs, attrs) {}
235-
236- BeamSearchOp(const BeamSearchOp& o)
237- : framework::OperatorWithKernel(
238- static_cast<const framework::OperatorBase&>(o)) {
239- PADDLE_THROW("Not Implemented");
240- }
241- */
242227 public:
243228 using framework::OperatorWithKernel::OperatorWithKernel;
244229
245230 protected:
246- void InferShape (framework::InferShapeContext* ctx) const override {
231+ void InferShape (framework::InferShapeContext * ctx) const override {
247232 for (const std::string &arg :
248233 std::vector<std::string>({" pre_ids" , " ids" , " scores" })) {
249- PADDLE_ENFORCE (ctx->HasInput (arg),
250- " BeamSearch need input argument '%s' " , arg);
234+ PADDLE_ENFORCE (ctx->HasInput (arg), " BeamSearch need input argument '%s' " ,
235+ arg);
251236 }
252237 for (const std::string &arg :
253238 std::vector<std::string>({" selected_ids" , " selected_scores" })) {
@@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
263248 framework::OpKernelType kt = framework::OpKernelType (
264249 framework::ToDataType (
265250 ctx.Input <framework::LoDTensor>(" pre_ids" )->type ()),
266- platform::CPUPlace ());
251+ platform::CPUPlace ());
267252 std::cout << " Get Expected type 2\n " ;
268- // kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
269- // std::cout << "Get Expected type 3\n";
270253 return kt;
271254 }
272- /*
273- private:
274- void RunImpl(const framework::Scope& scope,
275- const platform::Place& dev_place) const override {
276- auto ids_var = scope.FindVar(Input("ids"));
277- auto scores_var = scope.FindVar(Input("scores"));
278- auto pre_ids_var = scope.FindVar(Input("pre_ids"));
279- PADDLE_ENFORCE_NOT_NULL(ids_var);
280- PADDLE_ENFORCE_NOT_NULL(scores_var);
281- PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
282-
283- auto& ids = ids_var->Get<framework::LoDTensor>();
284- auto& scores = scores_var->Get<framework::LoDTensor>();
285- auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
286- size_t level = Attr<int>("level");
287- size_t beam_size = Attr<int>("beam_size");
288- int end_id = Attr<int>("end_id");
289- BeamSearch alg(ids, scores, level, beam_size, end_id);
290-
291- auto selected_ids_var = scope.FindVar(Output("selected_ids"));
292- auto selected_scores_var = scope.FindVar(Output("selected_scores"));
293- PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
294- PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
295- auto& selected_ids_tensor =
296- *selected_ids_var->GetMutable<framework::LoDTensor>();
297- auto& selected_scores_tensor =
298- *selected_scores_var->GetMutable<framework::LoDTensor>();
299- alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
300- }
301- */
302255};
303256
304257
305- /*
306- class BeamSearchInferShape : public framework::InferShapeBase {
307- public:
308- void operator()(framework::InferShapeContext *context) const override {
309- for (const std::string &arg :
310- std::vector<std::string>({"pre_ids", "ids", "scores"})) {
311- PADDLE_ENFORCE(context->HasInput(arg),
312- "BeamSearch need input argument '%s'", arg);
313- }
314- for (const std::string &arg :
315- std::vector<std::string>({"selected_ids", "selected_scores"})) {
316- PADDLE_ENFORCE(context->HasOutput(arg),
317- "BeamSearch need output argument '%s'", arg);
318- }
319- }
320- };
321- */
322258class BeamSearchInferVarType : public framework ::VarTypeInference {
323259 public:
324260 void operator ()(const framework::OpDesc &op_desc,
@@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
334270
335271} // namespace operators
336272} // namespace paddle
337- /*
338- REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
339- paddle::operators::BeamSearchProtoAndCheckerMaker,
340- paddle::operators::BeamSearchInferShape,
341- paddle::operators::BeamSearchInferVarType,
342- paddle::framework::EmptyGradOpMaker);
343- */
273+
274+
344275namespace ops = paddle::operators;
345- REGISTER_OP_WITHOUT_GRADIENT (beam_search, ops::BeamSearchOp,
346- ops::BeamSearchOpMaker,
347- ops::BeamSearchInferVarType);
276+
277+ REGISTER_OPERATOR (beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker,
278+ ops::BeamSearchInferVarType);
348279REGISTER_OP_CPU_KERNEL (
349280 beam_search,
350281 ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float >,
351- ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double >);
282+ ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double >,
283+ ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int >,
284+ ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int64_t >);
0 commit comments