Skip to content

Commit 382e460

Browse files
[Phi]add pad3d kernel into phi (#40701)
* add pad3d kernel into phi * add pad3d infermeta * fix build error * remove raw pad3d infershape function
1 parent 0e1191f commit 382e460

File tree

11 files changed

+2347
-1523
lines changed

11 files changed

+2347
-1523
lines changed

paddle/fluid/operators/pad3d_op.cc

Lines changed: 7 additions & 730 deletions
Large diffs are not rendered by default.

paddle/fluid/operators/pad3d_op.cu

Lines changed: 0 additions & 793 deletions
This file was deleted.

paddle/phi/infermeta/unary.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,77 @@ void PadInferMeta(const MetaTensor& input,
877877
out->set_dtype(input.dtype());
878878
}
879879

880+
void Pad3dInferMeta(const MetaTensor& x,
881+
const ScalarArray& paddings_scalar_array,
882+
const std::string& mode,
883+
float value,
884+
const std::string& data_format,
885+
MetaTensor* out,
886+
MetaConfig config) {
887+
auto x_dim = x.dims();
888+
PADDLE_ENFORCE_EQ(x_dim.size(),
889+
5,
890+
errors::InvalidArgument(
891+
"The size of Input(X)'s dimension should be equal to "
892+
"5, but received %d. ",
893+
x_dim.size()));
894+
895+
std::vector<int64_t> out_dims(x_dim.size());
896+
out_dims[0] = x_dim[0];
897+
if (paddings_scalar_array.FromTensor()) {
898+
if (config.is_runtime) {
899+
PADDLE_ENFORCE_EQ(
900+
paddings_scalar_array.GetData().size(),
901+
6,
902+
errors::InvalidArgument("Shape of Input(Paddings) should be equal to "
903+
"[6], but received [%d].",
904+
paddings_scalar_array.GetData().size()));
905+
}
906+
out_dims[1] = x_dim[1];
907+
out_dims[2] = x_dim[2];
908+
out_dims[3] = x_dim[3];
909+
} else {
910+
auto paddings = paddings_scalar_array.GetData();
911+
912+
PADDLE_ENFORCE_EQ(
913+
paddings.size(),
914+
6,
915+
errors::InvalidArgument(
916+
"Size of paddings should be equal to 6, but received %d.",
917+
static_cast<int>(paddings.size())));
918+
if (data_format == "NCDHW") {
919+
out_dims[1] = x_dim[1]; // channel
920+
out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0))
921+
? x_dim[2]
922+
: (x_dim[2] + paddings[4] + paddings[5]); // depth
923+
924+
out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0))
925+
? x_dim[3]
926+
: (x_dim[3] + paddings[2] + paddings[3]); // height
927+
928+
out_dims[4] = ((!config.is_runtime) && (x_dim[4] < 0))
929+
? x_dim[4]
930+
: (x_dim[4] + paddings[0] + paddings[1]); // width
931+
} else { // NDHWC
932+
out_dims[4] = x_dim[4]; // channel
933+
934+
out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0))
935+
? x_dim[1]
936+
: (x_dim[1] + paddings[4] + paddings[5]); // depth
937+
out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0))
938+
? x_dim[2]
939+
: (x_dim[2] + paddings[2] + paddings[3]); // height
940+
out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0))
941+
? x_dim[3]
942+
: (x_dim[3] + paddings[0] + paddings[1]); // width
943+
}
944+
}
945+
946+
out->set_dims(phi::make_ddim(out_dims));
947+
out->set_dtype(x.dtype());
948+
out->share_lod(x);
949+
}
950+
880951
void PixelShuffleInferMeta(const MetaTensor& x,
881952
int upscale_factor,
882953
const std::string& data_format,

paddle/phi/infermeta/unary.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ void PadInferMeta(const MetaTensor& input,
147147
MetaTensor* out,
148148
MetaConfig config = MetaConfig());
149149

150+
void Pad3dInferMeta(const MetaTensor& x,
151+
const ScalarArray& paddings,
152+
const std::string& mode,
153+
float value,
154+
const std::string& data_format,
155+
MetaTensor* out,
156+
MetaConfig config = MetaConfig());
157+
150158
void PixelShuffleInferMeta(const MetaTensor& x,
151159
int upscale_factor,
152160
const std::string& data_format,

0 commit comments

Comments
 (0)