
Commit 9b62071

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into tensor_inherit_from_dense_tensor
2 parents: 2d91053 + 7f3b087


48 files changed: 683 additions & 522 deletions

paddle/fluid/framework/lod_tensor.h

Lines changed: 51 additions & 1 deletion
@@ -108,14 +108,64 @@ bool CheckAbsLoD(const LoD& in, int tensor_height = -1);
  */
 class LoDTensor : public Tensor {
  public:
-  using Tensor::Tensor;
+  LoDTensor() : Tensor() {}
+
+  explicit LoDTensor(const LoD& lod) : lod_(lod) {}
+
+  void set_lod(const LoD& lod) { lod_ = lod; }
+
+  const LoD& lod() const { return lod_; }
+
+  LoD* mutable_lod() { return &lod_; }
+
+  /*
+   * Get the start offset and end offset of an element from LoD.
+   */
+  std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
+    PADDLE_ENFORCE_LT(
+        level, NumLevels(),
+        platform::errors::InvalidArgument(
+            "The input level of LoD is invalid, it should be less than LoD "
+            "size. The input level is %zu, the LoD size is %zu.",
+            level, NumLevels()));
+    PADDLE_ENFORCE_LT(elem, NumElements(level),
+                      platform::errors::InvalidArgument(
+                          "The input element of LoD is invalid, it should be "
+                          "less than the number of elements in its level. "
+                          "The input element is %zu, the number of elements in "
+                          "its level is %zu.",
+                          elem, NumElements(level)));
+    return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
+  }
+
+  /*
+   * Number of LoDTensor's levels, each level has units of data, for example,
+   * in the sentence's view, article, paragraph, sentence are 3 levels.
+   */
+  size_t NumLevels() const { return lod_.size(); }
+  /*
+   * Number of elements in a level.
+   */
+  size_t NumElements(size_t level = 0) const {
+    PADDLE_ENFORCE_LT(
+        level, NumLevels(),
+        platform::errors::InvalidArgument(
+            "The input level of LoD is invalid, it should be less than LoD "
+            "size. The input level is %zu, the LoD size is %zu.",
+            level, NumLevels()));
+    // the last offset is the end of last element
+    return (lod_)[level].size() - 1;
+  }
 
   // Split LoDTensor and copy to each place specified in places.
   std::vector<LoDTensor> SplitLoDTensor(
       const std::vector<platform::Place> places) const;
 
   void MergeLoDTensor(const std::vector<const LoDTensor*>& lod_tensors,
                       platform::Place place);
+
+ private:
+  LoD lod_;
 };
 
 /*
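To make the new interface concrete, here is a minimal usage sketch of the LoD accessors added above; the helper function and the example offsets are illustrative, not part of the commit.

```cpp
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"

namespace fw = paddle::framework;

// One-level LoD describing three sequences of lengths 2, 3 and 1.
void LoDAccessorSketch() {
  fw::LoD lod{{0, 2, 5, 6}};
  fw::LoDTensor t(lod);

  size_t levels = t.NumLevels();    // 1 level
  size_t elems = t.NumElements(0);  // 3 elements (offsets minus one)
  // Start/end offsets of the second sequence: {2, 5}.
  std::pair<size_t, size_t> span = t.lod_element(0, 1);
  (void)levels;
  (void)elems;
  (void)span;
}
```

NumElements returns one less than the number of offsets because the trailing offset marks the end of the last element.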

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 17 additions & 1 deletion
@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   if (op_with_kernel == nullptr) {
     instr_node.OpBase()->Run(*local_scope, place_);
   } else {
-    instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    // fit for pten
+    if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
+      VLOG(4) << "Run pten kernel: " << op->Type();
+      VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
+              << &instr_node.DeviceContext();
+      op_with_kernel->BuildPtenKernelContext(
+          *instr_node.InnerRuntimeContext().get(),
+          const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
+
+      (*instr_node.PtenKernel())(instr_node.PtenKernelContext());
+
+      op_with_kernel->WriteBackToOutputs(
+          instr_node.InnerRuntimeContext().get());
+      instr_node.PtenKernelContext()->ClearData();
+    } else {
+      instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    }
   }
 }

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 38 additions & 10 deletions
@@ -19,10 +19,13 @@
 #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
 #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include "paddle/pten/core/kernel_factory.h"
 
 PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");
+DECLARE_bool(run_pten_kernel);
+
 namespace paddle {
 namespace framework {
 namespace interpreter {
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
     // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
     deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
   } else {
+    auto op_with_kernel =
+        static_cast<const framework::OperatorWithKernel*>(op);
     // construct RuntimeContext and analysis KernelType
     RuntimeContext runtime_context({}, {});
     runtime_context.inputs.swap(ins_map);
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
       // TODO(Aurelius84): In case of control flow ops, they are NOT
       // inheritted
       // from OperatorWithKernel.
-      static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
-          &infer_shape_ctx);
+      op_with_kernel->InferShape(&infer_shape_ctx);
     }
 
     auto kernels_iter = all_op_kernels.find(op->Type());
@@ -367,21 +371,25 @@ void build_op_func_list(const platform::Place& place,
         platform::DeviceContextPool::Instance();
     auto* dev_ctx = pool.Get(place);
     Scope scope;
-    auto expected_kernel_key =
-        dynamic_cast<const framework::OperatorWithKernel*>(op)
-            ->GetExpectedKernelType(
-                ExecutionContext(*op, scope, *dev_ctx, runtime_context));
+    auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
+        ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
     // change device by the device_guard()
     apply_device_guard(op, place, &expected_kernel_key);
     VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
 
     // step 3. apply data transforms and insert data transfer ops
     VariableValueMap& ins_map_temp = runtime_context.inputs;
+
+    // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
+    // ApplyDataTransform
     ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
                        &op_func_node, vec_func_list, use_local_scope);
+    op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
+        op_func_node.operator_base_.get());
+
     // step 4. Run op kernel
-    VLOG(3) << op->Type()
+    VLOG(3) << op_with_kernel->Type()
             << " : expected_kernel_key : " << expected_kernel_key;
 
     if (platform::is_gpu_place(expected_kernel_key.place_)) {
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
     }
     op_func_node.dev_ctx_ = dev_ctx;
 
-    auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
+    auto exec_ctx =
+        ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
 
     auto kernel_iter = kernels.find(expected_kernel_key);
     PADDLE_ENFORCE_NE(
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
         "Operator (%s) does not have kernel for %s.", op->Type(),
         KernelTypeToString(expected_kernel_key)));
 
-    op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
-    op_func_node.kernel_func_(exec_ctx);
+    auto run_pten_kernel = false;
+
+    if (FLAGS_run_pten_kernel &&
+        pten::KernelFactory::Instance().HasCompatiblePtenKernel(
+            op_with_kernel->Type())) {
+      op_with_kernel->ChoosePtenKernel(exec_ctx);
+      run_pten_kernel = op_with_kernel->PtenKernel()->IsValid();
+    }
+
+    if (run_pten_kernel) {
+      op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
+      op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
+      op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
+
+      (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
+      op_with_kernel->WriteBackToOutputs(&runtime_context);
+      op_func_node.pt_kernel_context_->ClearData();
+    } else {
+      op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
+      op_func_node.kernel_func_(exec_ctx);
+    }
 
     // post-process grad_op.outputs if need cast complex grad into real grad.
     // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
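Condensed, the selection step added to build_op_func_list reduces to the check below. This is a standalone sketch for readability: the helper name is hypothetical, while the flag, registry query, and operator calls are the ones introduced in this diff.

```cpp
#include "gflags/gflags.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/core/kernel_factory.h"

DECLARE_bool(run_pten_kernel);

// Hypothetical helper: returns true when the op should be executed through
// the pten kernel path instead of the regular fluid OpKernel.
bool ShouldRunPtenKernel(const paddle::framework::OperatorWithKernel* op,
                         const paddle::framework::ExecutionContext& exec_ctx) {
  // Global switch plus a per-op compatibility check in the pten registry.
  if (!FLAGS_run_pten_kernel ||
      !pten::KernelFactory::Instance().HasCompatiblePtenKernel(op->Type())) {
    return false;
  }
  // Caches the chosen kernel inside the operator (see PtenKernel()).
  op->ChoosePtenKernel(exec_ctx);
  return op->PtenKernel()->IsValid();  // otherwise fall back to kernel_func_
}
```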

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 8 additions & 0 deletions
@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
   return op_func_node_.kernel_func_;
 }
 
+pten::Kernel* Instruction::PtenKernel() const {
+  return op_func_node_.pt_kernel_;
+}
+
+pten::KernelContext* Instruction::PtenKernelContext() const {
+  return op_func_node_.pt_kernel_context_;
+}
+
 OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
 
 OperatorBase* Instruction::OpBase() const {

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 9 additions & 0 deletions
@@ -295,6 +295,11 @@ struct OpFuncNode {
 
   OpKernelComputeFunc kernel_func_;
   platform::DeviceContext* dev_ctx_;  // not owned
+
+  // fit for pten kernel
+  pten::Kernel* pt_kernel_{nullptr};  // not owned
+  pten::KernelContext* pt_kernel_context_{nullptr};  // not owned
+
   OpFuncType type_;
 };
 
@@ -313,6 +318,10 @@ class Instruction {
 
   OpKernelComputeFunc KernelFunc() const;
 
+  pten::Kernel* PtenKernel() const;
+
+  pten::KernelContext* PtenKernelContext() const;
+
   OpFuncType KernelType() const;
 
   OperatorBase* OpBase() const;

paddle/fluid/framework/operator.cc

Lines changed: 3 additions & 0 deletions
@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
 
 void OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
+  if (pt_kernel_context_ == nullptr) {
+    pt_kernel_context_.reset(new pten::KernelContext());
+  }
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor

paddle/fluid/framework/operator.h

Lines changed: 14 additions & 8 deletions
@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
   virtual KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;
 
+  /* member functions for adapting to pten lib */
+  void ChoosePtenKernel(const ExecutionContext& ctx) const;
+
+  void BuildPtenKernelContext(const RuntimeContext& ctx,
+                              platform::DeviceContext* dev_ctx) const;
+
+  void WriteBackToOutputs(RuntimeContext* ctx) const;
+
+  pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
+
+  pten::KernelContext* PtenKernelContext() const {
+    return pt_kernel_context_.get();
+  }
+
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
@@ -595,14 +609,6 @@
   Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                    const std::string& name) const;
 
-  /* member functions for adapting to pten lib */
-  void ChoosePtenKernel(const ExecutionContext& ctx) const;
-
-  void BuildPtenKernelContext(const RuntimeContext& ctx,
-                              platform::DeviceContext* dev_ctx) const;
-
-  void WriteBackToOutputs(RuntimeContext* ctx) const;
-
  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
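Moving these adapter methods from the private section into the public interface is what lets the new standalone executor (see interpretercore_util.cc above) drive the pten path from outside the operator. A minimal sketch of that calling sequence follows, assuming the registry compatibility check has already passed; the free function itself is illustrative, the member calls are the ones made public here.

```cpp
#include "paddle/fluid/framework/operator.h"

namespace fw = paddle::framework;

// Illustrative driver: selects and runs the pten kernel of an operator.
// Returns false when no valid pten kernel is available, so the caller can
// fall back to the regular fluid OpKernel path.
bool TryRunThroughPten(const fw::OperatorWithKernel* op,
                       const fw::ExecutionContext& exec_ctx,
                       fw::RuntimeContext* runtime_ctx,
                       paddle::platform::DeviceContext* dev_ctx) {
  op->ChoosePtenKernel(exec_ctx);
  if (op->PtenKernel() == nullptr || !op->PtenKernel()->IsValid()) {
    return false;
  }
  op->BuildPtenKernelContext(*runtime_ctx, dev_ctx);  // pack ins/outs/attrs
  (*op->PtenKernel())(op->PtenKernelContext());       // run the pten kernel
  op->WriteBackToOutputs(runtime_ctx);                // copy results back
  op->PtenKernelContext()->ClearData();               // context is reused later
  return true;
}
```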

paddle/fluid/operators/cast_op.h

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ limitations under the License. */
 
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/include/manipulation.h"
+#include "paddle/pten/kernels/cast_kernel.h"
 
 namespace paddle {
 namespace operators {

paddle/fluid/operators/expand_as_v2_op.cc

File mode changed: 100644 → 100755
Lines changed: 11 additions & 0 deletions
@@ -12,6 +12,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/expand_as_v2_op.h"
 #include <memory>
 #include <vector>
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X",
              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
              "X is the input to be expanded.");
+    AddInput("Y",
+             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
+             "Expand X according to the shape of Y.")
+        .AsDispensable();
     AddOutput("Out",
               "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
               "The rank of Output(Out) have the same with Input(X). "
@@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
 #endif
+
+REGISTER_OP_VERSION(expand_as_v2)
+    .AddCheckpoint(
+        R"ROC(fix expand_as_v2 and add new input [Y])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewInput(
+            "Y", "Expand X according to the shape of Y"));

paddle/fluid/operators/expand_as_v2_op.h

File mode changed: 100644 → 100755
Lines changed: 25 additions & 8 deletions
@@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_NE(target_shape[i], 0,
                         platform::errors::InvalidArgument(
                             "The value of target shape cannot be zero."));
-      if (vec_in_dims[i] != 1) {
+      if (i < diff) {
+        PADDLE_ENFORCE_GT(
+            target_shape[i], 0,
+            platform::errors::InvalidArgument(
+                "The expanded size (%d) for non-existing dimensions must be "
+                "positive for expand_as_v2 op.",
+                target_shape[i]));
+        repeat_times[i] = target_shape[i];
+      } else if (target_shape[i] > 0) {
+        if (vec_in_dims[i] != 1) {
+          PADDLE_ENFORCE_EQ(
+              vec_in_dims[i], target_shape[i],
+              platform::errors::InvalidArgument(
+                  "The value (%d) of the non-singleton dimension does not match"
+                  " the corresponding value (%d) in shape for expand_as_v2 op.",
+                  vec_in_dims[i], target_shape[i]));
+          repeat_times[i] = 1;
+        } else {
+          repeat_times[i] = target_shape[i];
+        }
+      } else {
         PADDLE_ENFORCE_EQ(
-            vec_in_dims[i], target_shape[i],
+            target_shape[i], -1,
             platform::errors::InvalidArgument(
-                "The value (%d) of the non-singleton dimension does not match"
-                " the corresponding value (%d) in "
-                "target tensor for expand_as_v2 op.",
-                vec_in_dims[i], target_shape[i]));
+                "When the value in shape is negative for expand_as_v2 op, "
+                "only -1 is supported, but the value received is %d.",
+                target_shape[i]));
         repeat_times[i] = 1;
-      } else {
-        repeat_times[i] = target_shape[i];
       }
     }
     auto* out0 = context.Output<Tensor>("Out");