@@ -877,6 +877,77 @@ void PadInferMeta(const MetaTensor& input,
877
877
out->set_dtype (input.dtype ());
878
878
}
879
879
880
+ void Pad3dInferMeta (const MetaTensor& x,
881
+ const ScalarArray& paddings_scalar_array,
882
+ const std::string& mode,
883
+ float value,
884
+ const std::string& data_format,
885
+ MetaTensor* out,
886
+ MetaConfig config) {
887
+ auto x_dim = x.dims ();
888
+ PADDLE_ENFORCE_EQ (x_dim.size (),
889
+ 5 ,
890
+ errors::InvalidArgument (
891
+ " The size of Input(X)'s dimension should be equal to "
892
+ " 5, but received %d. " ,
893
+ x_dim.size ()));
894
+
895
+ std::vector<int64_t > out_dims (x_dim.size ());
896
+ out_dims[0 ] = x_dim[0 ];
897
+ if (paddings_scalar_array.FromTensor ()) {
898
+ if (config.is_runtime ) {
899
+ PADDLE_ENFORCE_EQ (
900
+ paddings_scalar_array.GetData ().size (),
901
+ 6 ,
902
+ errors::InvalidArgument (" Shape of Input(Paddings) should be equal to "
903
+ " [6], but received [%d]." ,
904
+ paddings_scalar_array.GetData ().size ()));
905
+ }
906
+ out_dims[1 ] = x_dim[1 ];
907
+ out_dims[2 ] = x_dim[2 ];
908
+ out_dims[3 ] = x_dim[3 ];
909
+ } else {
910
+ auto paddings = paddings_scalar_array.GetData ();
911
+
912
+ PADDLE_ENFORCE_EQ (
913
+ paddings.size (),
914
+ 6 ,
915
+ errors::InvalidArgument (
916
+ " Size of paddings should be equal to 6, but received %d." ,
917
+ static_cast <int >(paddings.size ())));
918
+ if (data_format == " NCDHW" ) {
919
+ out_dims[1 ] = x_dim[1 ]; // channel
920
+ out_dims[2 ] = ((!config.is_runtime ) && (x_dim[2 ] < 0 ))
921
+ ? x_dim[2 ]
922
+ : (x_dim[2 ] + paddings[4 ] + paddings[5 ]); // depth
923
+
924
+ out_dims[3 ] = ((!config.is_runtime ) && (x_dim[3 ] < 0 ))
925
+ ? x_dim[3 ]
926
+ : (x_dim[3 ] + paddings[2 ] + paddings[3 ]); // height
927
+
928
+ out_dims[4 ] = ((!config.is_runtime ) && (x_dim[4 ] < 0 ))
929
+ ? x_dim[4 ]
930
+ : (x_dim[4 ] + paddings[0 ] + paddings[1 ]); // width
931
+ } else { // NDHWC
932
+ out_dims[4 ] = x_dim[4 ]; // channel
933
+
934
+ out_dims[1 ] = ((!config.is_runtime ) && (x_dim[1 ] < 0 ))
935
+ ? x_dim[1 ]
936
+ : (x_dim[1 ] + paddings[4 ] + paddings[5 ]); // depth
937
+ out_dims[2 ] = ((!config.is_runtime ) && (x_dim[2 ] < 0 ))
938
+ ? x_dim[2 ]
939
+ : (x_dim[2 ] + paddings[2 ] + paddings[3 ]); // height
940
+ out_dims[3 ] = ((!config.is_runtime ) && (x_dim[3 ] < 0 ))
941
+ ? x_dim[3 ]
942
+ : (x_dim[3 ] + paddings[0 ] + paddings[1 ]); // width
943
+ }
944
+ }
945
+
946
+ out->set_dims (phi::make_ddim (out_dims));
947
+ out->set_dtype (x.dtype ());
948
+ out->share_lod (x);
949
+ }
950
+
880
951
void PixelShuffleInferMeta (const MetaTensor& x,
881
952
int upscale_factor,
882
953
const std::string& data_format,
0 commit comments