Commit e4a134a

Authored by MingMingShangTian, chenwhql, zyfncg, YuanRisheng, and Shixiaowei02
support multiply inputs and outputs (PaddlePaddle#36851)
- initial tensor design & sign kernel demo
- add move constructor for meta & add lodtensor
- add dirs & sign xpu kernel
- add mean cpu&cuda kernel impl
- move sign & mean xpu & npu kernel
- add selected_rows basic impl
- refactor design, BaseTensor to DenseTensor, etc.
- add scale mkldnn kernel
- polish xpu & npu impl details
- fix mkldnn reuse compile failed
- change tensor operation lib name
- rename util filename
- add more comments
- change TensorImplInterface to TensorInterface
- add kernel key and factory
- remove MKLDNNTensorMeta, add MKLDNNDenseTensor
- change XXDeviceContext to XXContext
- add base kernel registrar utils & test on sign
- replace boost::any by paddle::any
- fix several ci failed
- fix npu compile error
- add ordered map util
- fix multiple ordered_map compile errors
- move dev into include dir
- support sign op in static op run
- fix static op run error
- fix new executor compile failed
- add dygraph branch & remove sign_op.h
- fix test_infer_no_need_buffer_slots
- fix rocm compile link error
- fix unitybuild error & clear glog
- fix npu compile failed
- skip quant trans test
- fix part windows compile problem
- fix xpu enforce error
- fix inference test failed
- remove ordered_map to solve quant failed
- fix part of rcom compile faild
- add more register kernels
- revert scale kernel temporarily
- fix code format error
- add new kernel registrar marco
- rename top to tcmpt
- revert xpu, npu, mkldnn impl & remove op def
- add kernel args parse functor to auto parse args
- revert some change & add scale kernels
- add op proto in dygraph kernelcontext building
- polish kernel dispatch logic & nameing rule
- fix scale kernel match error
- fix scale test failed
- add mean API and unittest
- test mean api success
- add branch to solve compiled error
- skip clang format error
- add mean skip rule in op_library
- add dot kernel, api and unittest (#6)
- remove old kernel and add symbol link
- fix dot compiled failed
- add merco for module declare
- fix npu and xpu compile error
- revert sign, mean, scale, dot kernel removing
- add comment for keeping old kernel impl
- fix mutable_data error
- fix bfloat16 conflit
- fix inference undef error
- adapt to msvc compile rules
- polish comment for template inst
- add cmake template instantiation for win
- fix backend to place device id bug
- fix ifdef error
- Op2functor (#7)
- add kernel args maker class
- make args maker non-const
- remove debug log
- modify codes by review options
- split constructPrKernelContext function
- fix output name bug
- fix test_mean_op test_sign_op failed
- fill_any_like kernel refactor (#10)
- fill_any_like kernel refactor
- remove useless code of full_like c++ api
- skip dtype for fill_any_like
- add attrs for kernel key constrcut
- add use_pt_kernel Flags to control whether to use pt kernel (#13)
- add use_pt_kernel Flags to control whether to use pt kernel
- change the default value to true for cheking pt kernels
- fix mutable_data cuda place error
- move high level apis into hapi
- remove selectedrows adapting temporarily
- Support Scalar in Tensor Compute Library (#14)
- fill_any_like kernel refactor
- remove useless code of full_like c++ api
- Support Scalar in Tensor Compute Library
- add scalar in dygraph and static graph mode
- keep the basic type for attr, instead of using scalar for all
- merge the code
- remove mkldnn tensor & polish details
- use flat_hash_map and small_vector in kernel factory
- Refactor flatten kernel (#12)
- refactor flatten kernel
- update infershape function
- fix compile bugs
- fix bugs when merge
- fix compiler bugs
- fix bugs when run test_flatten_api
- fix bugs when run test
- Revert "use flat_hash_map and small_vector in kernel factory" (this reverts commit 2309149)
- Move cpu, cuda and other device code into kernels (#15)
- fill_any_like kernel refactor
- remove useless code of full_like c++ api
- Support Scalar in Tensor Compute Library
- add scalar in dygraph and static graph mode
- keep the basic type for attr, instead of using scalar for all
- merge the code
- start refactor matmul
- move cpu, cuda and other device modules into kernels
- merge code
- polish code in operator.cc
- Perfect unitests (#16)
- perfect unittest
- update license
- replace with flat_hash_map, small_vector (PaddlePaddle#19)
- fix small_vector build error on windows platform
- replace with flat_hash_map, small_vector
- remove todo
- Perfect unitests (PaddlePaddle#20)
- perfect unittest
- update license
- fix bug when run tcmpt_utils_test
- refactor execution adapting impl
- fix insert conflit
- Fix CI bug of test_yolov3 (PaddlePaddle#21)
- fill_any_like kernel refactor
- remove useless code of full_like c++ api
- Support Scalar in Tensor Compute Library
- add scalar in dygraph and static graph mode
- keep the basic type for attr, instead of using scalar for all
- merge the code
- start refactor matmul
- move cpu, cuda and other device modules into kernels
- merge code
- polish code in operator.cc
- Fix CI bug of test_yolov3
- add the tensor base class, test=develop (#17)
- update the tensor base class, test=develop
- remove two funcs, test=develop
- update the error msg, test=develop (Co-authored-by: Chen Weihang <chenweihang@baidu.com>)
- [no-verify] commit backend and tensor signature changes
- Rename tcmpt to pten (PaddlePaddle#23)
- rename tcmpt to pten
- update omitted files for rename to pten
- update omitted file for rename to pten
- remove k of all enum var
- remove kernel_instantiate (PaddlePaddle#26)
- remove symbols and spatial_tensor
- change common to functions
- readd share tensor impl methods
- add a candidate dense tensor class, test=develop (PaddlePaddle#28)
- change all Pt to Pten
- resolve conflit with xiaowei
- Op2functor opt1 (PaddlePaddle#27)
- replace to small vector and change to const &
- add std::move (Co-authored-by: Chen Weihang <chenweihang@baidu.com>)
- polish kernel factory and kernel registry
- fix operator test error msg mismatch
- remove tensor signature and backend set member
- move scalar and polish enforce
- revert dtype layout change to fix error
- fix enum operator override error
- add several base unittests
- add pten utils tests
- polish some details
- Dev/op2func refactor 3 (PaddlePaddle#30)
- add a candidate dense tensor class, test=develop
- remove TensorBase::backend(), test=develop
- remove some ops, test=develop
- cherry-pick the pr of tensor meta, test=develop
- moves the dense tensor and some ops, test=develop
- update the linalg operator, test=develop
- update other operators, test=develop
- fix errors, test=develop
- fix bugs, test=develop
- try to resolve the problem of windows ci, test=develop
- updates codes, test=develop
- fix the tensor_utils.cc, test=develop
- modify the dense tensor, test=develop
- fix the data type, test=develop (Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>)
- polish some details
- polish kernel signature details
- fix a bug about offsets of the tensor, test=develop (PaddlePaddle#31) (Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>)
- support multiply inputs and outputs
- rm attrs {}
- fix multioutputs bug
- merge develop
- remove unsed header file
- add missing & in const reference
- modify inputAt, outputAt to inputBetween, outputBetween

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: zyfncg <1370305206@qq.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent 4a7f1a0 commit e4a134a

File tree

4 files changed: +112 −23 lines


paddle/pten/core/kernel_context.h

Lines changed: 37 additions & 8 deletions
@@ -52,37 +52,37 @@ class KernelContext {
   }
 
   void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
+    int index = inputs_.size();
     inputs_.emplace_back(std::move(input));
     // Record the start and end index of the input
-    int index = inputs_.size();
     input_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
   void EmplaceBackInputs(
-      paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
+      const paddle::SmallVector<std::shared_ptr<TensorBase>>& inputs) {
+    int index = inputs_.size();
     for (auto in : inputs) {
-      inputs_.emplace_back(in);
+      inputs_.emplace_back(std::move(in));
     }
     // Record the start and end index of the input
-    int index = inputs_.size();
     input_range_.emplace_back(
         std::pair<int, int>(index, index + inputs.size()));
   }
 
   void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
+    int index = outputs_.size();
     outputs_.emplace_back(std::move(output));
     // Record the start and end index of the input
-    int index = outputs_.size();
     output_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
   void EmplaceBackOutputs(
-      paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
+      const paddle::SmallVector<std::shared_ptr<TensorBase>>& outputs) {
+    int index = outputs_.size();
     for (auto out : outputs) {
-      outputs_.emplace_back(out);
+      outputs_.emplace_back(std::move(out));
     }
     // Record the start and end index of the input
-    int index = outputs_.size();
     output_range_.emplace_back(
         std::pair<int, int>(index, index + outputs.size()));
   }
@@ -96,11 +96,40 @@ class KernelContext {
     return static_cast<const TensorType&>(*(inputs_.at(idx)));
   }
 
+  template <typename TensorType>
+  std::vector<TensorType> InputBetween(size_t start, size_t end) const {
+    std::vector<TensorType> v;
+    for (size_t i = start; i < end; ++i) {
+      auto t = std::dynamic_pointer_cast<TensorType>(inputs_.at(i));
+      v.emplace_back(std::move(*t.get()));
+    }
+
+    return v;
+  }
+
+  const std::pair<int, int>& InputRangeAt(size_t idx) const {
+    return input_range_.at(idx);
+  }
+
+  const std::pair<int, int>& OutputRangeAt(size_t idx) const {
+    return output_range_.at(idx);
+  }
+
   template <typename TensorType>
   TensorType* MutableOutputAt(size_t idx) {
     return static_cast<TensorType*>(outputs_.at(idx).get());
   }
 
+  template <typename TensorType>
+  std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
+    std::vector<TensorType*> v;
+    for (size_t i = start; i < end; ++i) {
+      v.emplace_back(static_cast<TensorType*>(outputs_.at(i).get()));
+    }
+
+    return v;
+  }
+
   template <typename AttrType>
   AttrType AttrAt(size_t idx) const {
     try {

paddle/pten/core/kernel_registry.h

Lines changed: 8 additions & 0 deletions
@@ -62,9 +62,17 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
     } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
       args_def->AppendInput(
           default_key.backend(), default_tensor_layout, default_key.dtype());
+    } else if (arg_type ==
+               std::type_index(typeid(const std::vector<DenseTensor>&))) {
+      args_def->AppendInput(
+          default_key.backend(), default_tensor_layout, default_key.dtype());
     } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
       args_def->AppendOutput(
           default_key.backend(), default_tensor_layout, default_key.dtype());
+    } else if (arg_type ==
+               std::type_index(typeid(std::vector<DenseTensor*>))) {
+      args_def->AppendOutput(
+          default_key.backend(), default_tensor_layout, default_key.dtype());
     } else {
       // Attribute deal with
       // TODO(chenweihang): now here allow any types of attribute, maybe

paddle/pten/core/kernel_utils.h

Lines changed: 59 additions & 15 deletions
@@ -79,7 +79,30 @@ using XPUContext = paddle::platform::XPUDeviceContext;
                     "Kernel's Input should appear before Attributes.");       \
       static_assert(out_idx == 0,                                             \
                     "Kernel's Input should appear before Outputs.");          \
-      const tensor_type& arg = ctx->InputAt<tensor_type>(in_idx);             \
+      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);            \
+      const tensor_type& arg = ctx->InputAt<tensor_type>(range.first);        \
+      KernelCallHelper<Tail...>::                                             \
+          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(       \
+              ctx, pargs..., arg);                                            \
+    }                                                                         \
+  }
+
+#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)           \
+  template <typename... Tail>                                                 \
+  struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> {         \
+    template <int dev_ctx_idx,                                                \
+              int in_idx,                                                     \
+              int attr_idx,                                                   \
+              int out_idx,                                                    \
+              typename... PreviousArgs>                                       \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {         \
+      static_assert(attr_idx == 0,                                            \
+                    "Kernel's Input should appear before Attributes.");       \
+      static_assert(out_idx == 0,                                             \
+                    "Kernel's Input should appear before Outputs.");          \
+      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);            \
+      std::vector<tensor_type> arg = std::move(                               \
+          ctx->InputBetween<tensor_type>(range.first, range.second));         \
       KernelCallHelper<Tail...>::                                             \
           template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(       \
               ctx, pargs..., arg);                                            \
@@ -104,20 +127,39 @@ using XPUContext = paddle::platform::XPUDeviceContext;
     }                                                                         \
   }
 
-#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)             \
-  template <typename... Tail>                                              \
-  struct KernelCallHelper<tensor_type*, Tail...> {                         \
-    template <int dev_ctx_idx,                                             \
-              int in_idx,                                                  \
-              int attr_idx,                                                \
-              int out_idx,                                                 \
-              typename... PreviousArgs>                                    \
-    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
-      tensor_type* arg = ctx->MutableOutputAt<tensor_type>(out_idx);       \
-      KernelCallHelper<Tail...>::                                          \
-          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(    \
-              ctx, pargs..., arg);                                         \
-    }                                                                      \
+#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)                \
+  template <typename... Tail>                                                 \
+  struct KernelCallHelper<tensor_type*, Tail...> {                            \
+    template <int dev_ctx_idx,                                                \
+              int in_idx,                                                     \
+              int attr_idx,                                                   \
+              int out_idx,                                                    \
+              typename... PreviousArgs>                                       \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {         \
+      const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);          \
+      tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first);      \
+      KernelCallHelper<Tail...>::                                             \
+          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(       \
+              ctx, pargs..., arg);                                            \
+    }                                                                         \
+  }
+
+#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type)          \
+  template <typename... Tail>                                                 \
+  struct KernelCallHelper<std::vector<tensor_type*>, Tail...> {               \
+    template <int dev_ctx_idx,                                                \
+              int in_idx,                                                     \
+              int attr_idx,                                                   \
+              int out_idx,                                                    \
+              typename... PreviousArgs>                                       \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {         \
+      const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);          \
+      std::vector<tensor_type*> arg = std::move(                              \
+          ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
+      KernelCallHelper<Tail...>::                                             \
+          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(       \
+              ctx, pargs..., arg);                                            \
+    }                                                                         \
   }
 
 template <typename T>
@@ -152,6 +194,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   /* Input Helpers */
 
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
   // TODO(chenweihang): adapt SelectedRows
   // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
 
@@ -168,6 +211,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   /* Output Helpers */
 
   PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
   // TODO(chenweihang): adapt SelectedRows
   // PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);

paddle/pten/hapi/lib/kernel_dispatch.h

Lines changed: 8 additions & 0 deletions
@@ -122,6 +122,14 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
     key_set.dtype = x.type();
   }
 
+  void operator()(const std::vector<Tensor>& x) {
+    key_set.backend_set =
+        key_set.backend_set | detail::GetTensorBackendSet(x[0]);
+    // TODO(chenweihang): selecte multi layout and dtype
+    key_set.layout = x[0].layout();
+    key_set.dtype = x[0].type();
+  }
+
   // skip other type args, these args don't used in kernel selection
   template <typename T>
   void operator()(const T& x) {
