Skip to content

Commit d977bf7

Browse files
YuanRishengAurelius84
authored andcommitted
[PHI]Seperate xshape kernel from normal kernel (PaddlePaddle#44315)
* seperate xshape kernel from normal kernel * fix bugs in infermeta * fix compile bugs * fix compile bugs
1 parent 6f919fe commit d977bf7

21 files changed

+239
-61
lines changed

paddle/fluid/operators/einsum_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ namespace ops = paddle::operators;
106106

107107
DECLARE_INFER_SHAPE_FUNCTOR(einsum,
108108
EinsumInferShapeFunctor,
109-
PD_INFER_META(phi::EinsumInferMeta));
109+
PD_INFER_META(phi::EinsumRawInferMeta));
110110

111111
REGISTER_OPERATOR(einsum,
112112
ops::EinsumOp,

paddle/fluid/operators/squeeze_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ namespace ops = paddle::operators;
347347

348348
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
349349
SqueezeInferShapeFunctor,
350-
PD_INFER_META(phi::SqueezeInferMeta));
350+
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));
351351

352352
REGISTER_OPERATOR(squeeze,
353353
ops::SqueezeOp,

paddle/fluid/operators/unsqueeze_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
347347

348348
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
349349
Unsqueeze2InferShapeFunctor,
350-
PD_INFER_META(phi::UnsqueezeInferMeta));
350+
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));
351351

352352
namespace ops = paddle::operators;
353353
REGISTER_OPERATOR(unsqueeze,

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ add_custom_command(
325325
${dygraph_api_header_file}
326326
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp}
327327
${dygraph_api_source_file}
328-
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file}
329-
${api_gen_base} ${api_gen_file}
328+
DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file}
329+
${im_api_gen_file} ${api_gen_base} ${api_gen_file}
330330
VERBATIM)
331331

332332
# generate wrapped infermeta

paddle/phi/api/yaml/legacy_api.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,10 @@
582582
args : (Tensor[] x, str equation)
583583
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
584584
infer_meta :
585-
func : EinsumInferMeta
585+
func : EinsumRawInferMeta
586586
param : [x, equation]
587587
kernel :
588-
func : einsum
588+
func : einsum_raw
589589
backward : einsum_grad
590590

591591
- api : elementwise_pow
@@ -2047,9 +2047,9 @@
20472047
args : (Tensor x, int[] axes)
20482048
output : Tensor(out), Tensor(xshape)
20492049
infer_meta :
2050-
func : SqueezeInferMeta
2050+
func : SqueezeWithXShapeInferMeta
20512051
kernel :
2052-
func : squeeze
2052+
func : squeeze_with_xshape
20532053
view: (x -> out)
20542054
intermediate : xshape
20552055
backward : squeeze_grad
@@ -2290,9 +2290,9 @@
22902290
args : (Tensor x, IntArray axis)
22912291
output : Tensor(out), Tensor(xshape)
22922292
infer_meta :
2293-
func : UnsqueezeInferMeta
2293+
func : UnsqueezeWithXShapeInferMeta
22942294
kernel :
2295-
func : unsqueeze
2295+
func : unsqueeze_with_xshape
22962296
view: (x -> out)
22972297
intermediate : xshape
22982298
backward : unsqueeze_grad

paddle/phi/infermeta/unary.cc

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
570570

571571
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
572572
const std::string& equation,
573-
MetaTensor* out,
574-
std::vector<MetaTensor*> inner_cache,
575-
std::vector<MetaTensor*> xshape) {
573+
MetaTensor* out) {
576574
// collect the following informations to prepare einsum.
577575
LabelMap labelshape(0);
578576
LabelMap labeltype(LabelType::Reduction);
@@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
609607
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
610608
out->set_dims(make_ddim(output_dims));
611609
out->set_dtype(inputs[0]->dtype());
610+
}
611+
612+
void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
613+
const std::string& equation,
614+
MetaTensor* out,
615+
std::vector<MetaTensor*> inner_cache,
616+
std::vector<MetaTensor*> xshape) {
617+
EinsumInferMeta(inputs, equation, out);
612618
for (size_t i = 0; i < xshape.size(); ++i) {
613619
if (xshape[i] != nullptr) {
614620
xshape[i]->set_dims(inputs[i]->dims());
@@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x,
24482454

24492455
void SqueezeInferMeta(const MetaTensor& x,
24502456
const std::vector<int>& axes,
2451-
MetaTensor* out,
2452-
MetaTensor* xshape) {
2457+
MetaTensor* out) {
24532458
const auto& x_dims = x.dims();
24542459
// Check input tensor dims (<6) Eigen limit.
24552460
PADDLE_ENFORCE_LE(x_dims.size(),
@@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x,
24692474
out->share_lod(x);
24702475
}
24712476

2477+
out->set_dtype(x.dtype());
2478+
}
2479+
2480+
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
2481+
const std::vector<int>& axes,
2482+
MetaTensor* out,
2483+
MetaTensor* xshape) {
2484+
SqueezeInferMeta(x, axes, out);
2485+
const auto& x_dims = x.dims();
24722486
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
24732487
xshape_dims[0] = 0;
24742488
for (int i = 0; i < x_dims.size(); ++i) {
24752489
xshape_dims[i + 1] = x_dims[i];
24762490
}
2477-
xshape->set_dims(phi::make_ddim(xshape_dims));
2478-
xshape->share_lod(x);
2479-
xshape->set_dtype(x.dtype());
2480-
out->set_dtype(x.dtype());
2491+
if (xshape) {
2492+
xshape->set_dims(phi::make_ddim(xshape_dims));
2493+
xshape->share_lod(x);
2494+
xshape->set_dtype(x.dtype());
2495+
}
24812496
}
24822497

24832498
void StridedSliceRawInferMeta(const MetaTensor& x,
@@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
33103325
void UnsqueezeInferMeta(const MetaTensor& x,
33113326
const IntArray& axes,
33123327
MetaTensor* out,
3313-
MetaTensor* xshape,
33143328
MetaConfig config) {
33153329
const auto& x_dims = x.dims();
33163330
// Validity Check: input tensor dims (<6).
@@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
33393353
}
33403354
out->set_dtype(x.dtype());
33413355
}
3342-
if (xshape) {
3343-
// set xshape dims.
3344-
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
3345-
xshape_dims[0] = 0;
3346-
for (int i = 0; i < x_dims.size(); ++i) {
3347-
xshape_dims[i + 1] = x_dims[i];
3348-
}
3356+
}
33493357

3358+
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
3359+
const IntArray& axes,
3360+
MetaTensor* out,
3361+
MetaTensor* xshape,
3362+
MetaConfig config) {
3363+
const auto& x_dims = x.dims();
3364+
UnsqueezeInferMeta(x, axes, out, config);
3365+
// set xshape dims.
3366+
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
3367+
xshape_dims[0] = 0;
3368+
for (int i = 0; i < x_dims.size(); ++i) {
3369+
xshape_dims[i + 1] = x_dims[i];
3370+
}
3371+
if (xshape) {
33503372
xshape->set_dims(phi::make_ddim(xshape_dims));
33513373
xshape->share_lod(x);
33523374
xshape->set_dtype(x.dtype());

paddle/phi/infermeta/unary.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x,
9797

9898
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
9999
const std::string& equation,
100-
MetaTensor* out,
101-
std::vector<MetaTensor*> inner_cache,
102-
std::vector<MetaTensor*> xshape);
100+
MetaTensor* out);
101+
102+
void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
103+
const std::string& equation,
104+
MetaTensor* out,
105+
std::vector<MetaTensor*> inner_cache,
106+
std::vector<MetaTensor*> xshape);
103107

104108
void ExpandInferMeta(const MetaTensor& x,
105109
const IntArray& shape,
@@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta,
341345

342346
void SqueezeInferMeta(const MetaTensor& x,
343347
const std::vector<int>& axes,
344-
MetaTensor* out,
345-
MetaTensor* xshape);
348+
MetaTensor* out);
349+
350+
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
351+
const std::vector<int>& axes,
352+
MetaTensor* out,
353+
MetaTensor* xshape);
346354

347355
void StridedSliceRawInferMeta(const MetaTensor& x,
348356
const std::vector<int>& axes,
@@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
470478
void UnsqueezeInferMeta(const MetaTensor& x,
471479
const IntArray& axes,
472480
MetaTensor* out,
473-
MetaTensor* xshape,
474481
MetaConfig config = MetaConfig());
475482

483+
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
484+
const IntArray& axes,
485+
MetaTensor* out,
486+
MetaTensor* xshape,
487+
MetaConfig config = MetaConfig());
488+
476489
void UnStackInferMeta(const MetaTensor& x,
477490
int axis,
478491
int num,

paddle/phi/kernels/cpu/einsum_kernel.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,20 @@
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/kernels/impl/einsum_impl.h"
2020

21-
PD_REGISTER_KERNEL(einsum,
21+
PD_REGISTER_KERNEL(einsum_raw,
2222
CPU,
2323
ALL_LAYOUT,
2424
phi::EinsumKernelRaw,
2525
float,
2626
double,
2727
phi::dtype::complex<float>,
2828
phi::dtype::complex<double>) {}
29+
30+
PD_REGISTER_KERNEL(einsum,
31+
CPU,
32+
ALL_LAYOUT,
33+
phi::EinsumKernel,
34+
float,
35+
double,
36+
phi::dtype::complex<float>,
37+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/squeeze_kernel.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze,
3232
int64_t,
3333
phi::dtype::complex<float>,
3434
phi::dtype::complex<double>) {}
35+
36+
PD_REGISTER_KERNEL(squeeze_with_xshape,
37+
CPU,
38+
ALL_LAYOUT,
39+
phi::SqueezeWithXShapeKernel,
40+
float,
41+
double,
42+
phi::dtype::bfloat16,
43+
bool,
44+
int,
45+
uint8_t,
46+
int8_t,
47+
int64_t,
48+
phi::dtype::complex<float>,
49+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/unsqueeze_kernel.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze,
3333
int64_t,
3434
phi::dtype::complex<float>,
3535
phi::dtype::complex<double>) {}
36+
37+
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
38+
CPU,
39+
ALL_LAYOUT,
40+
phi::UnsqueezeWithXShapeKernel,
41+
float,
42+
double,
43+
phi::dtype::bfloat16,
44+
bool,
45+
int,
46+
int16_t,
47+
uint8_t,
48+
int8_t,
49+
int64_t,
50+
phi::dtype::complex<float>,
51+
phi::dtype::complex<double>) {}

0 commit comments

Comments
 (0)