
Commit cc00a23

[Prim][PIR] support instancenorm op dynamic forward in prim pir (#64598)
* support dynamic instancenorm
* fix narrow convert
* fix dtype
* update
* remove empty_shape
* Update composite.h
* Update composite.h
* fix scalar
1 parent b8ec413 commit cc00a23

File tree

3 files changed: +96 -9 lines changed

paddle/fluid/primitive/base/decomp_trans.cc
paddle/fluid/primitive/composite/composite.h
test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py


paddle/fluid/primitive/base/decomp_trans.cc

Lines changed: 7 additions & 9 deletions
@@ -46,15 +46,13 @@ std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
                                                            "pd_op.batch_norm_",
                                                            "pd_op.dropout"};
 //
-std::unordered_set<std::string> dynamic_shape_blacklist = {
-    "pd_op.squeeze",
-    "pd_op.unsqueeze",
-    "pd_op.batch_norm",
-    "pd_op.batch_norm_",
-    "pd_op.bmm",
-    "pd_op.flatten",
-    "pd_op.instance_norm",
-    "pd_op.one_hot"};
+std::unordered_set<std::string> dynamic_shape_blacklist = {"pd_op.squeeze",
+                                                           "pd_op.unsqueeze",
+                                                           "pd_op.batch_norm",
+                                                           "pd_op.batch_norm_",
+                                                           "pd_op.bmm",
+                                                           "pd_op.flatten",
+                                                           "pd_op.one_hot"};

 namespace {
 std::set<std::string> StringSplit(const std::string& str) {
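
With "pd_op.instance_norm" dropped from dynamic_shape_blacklist, instance_norm_decomp can now run even when the input shape is only symbolically known. A minimal sketch of driving that path, not part of the commit: it assumes the private helper paddle.framework.core._set_prim_all_enabled (used throughout the prim test suite) and compiles the net with a fully dynamic InputSpec.

import numpy as np
import paddle
from paddle.framework import core
from paddle.static import InputSpec

# Enable prim decomposition (private helper used by the prim tests).
core._set_prim_all_enabled(True)

# Compile instance_norm with every input dimension left dynamic.
net = paddle.jit.to_static(
    lambda x: paddle.nn.functional.instance_norm(x),
    input_spec=[InputSpec(shape=[None, None, None], dtype="float32")],
)

x = paddle.to_tensor(np.random.random([2, 32, 128]).astype("float32"))
y = net(x)  # pd_op.instance_norm can now be decomposed under dynamic shapes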

paddle/fluid/primitive/composite/composite.h

Lines changed: 72 additions & 0 deletions
@@ -953,6 +953,78 @@ std::tuple<Tensor, Tensor, Tensor> instance_norm_decomp(
     const paddle::optional<Tensor>& scale,
     const paddle::optional<Tensor>& bias,
     float epsilon) {
+  if (has_dynamic_shape(x.shape())) {
+    auto org_dtype = x.dtype();
+    Tensor x_cast = x;
+
+    bool need_cast = is_half_dtype(org_dtype);
+    if (need_cast) {
+      x_cast = cast<T>(x, DataType::FLOAT32);
+    }
+
+    std::vector<int64_t> axis;
+    auto x_dim = x.shape();
+    for (size_t i = 2; i < x_dim.size(); i++) {
+      axis.push_back(static_cast<int64_t>(i));
+    }
+
+    // out = (x - mean(x)) / sqrt(var + epsilon))
+    // var = mean((x-mean(x))^2)
+    auto mean_ = mean_decomp<T>(x_cast, axis, true);
+    auto difference = x_cast - mean_;
+    auto var_tmp1 = difference * difference;
+    auto variance = mean_decomp<T>(var_tmp1, axis, true);
+    auto var_shape = shape<T>(variance);
+    auto var_tmp3 = variance + full_scalar<T>(epsilon, variance.dtype());
+    auto rsqrt_var = rsqrt<T>(var_tmp3);
+    auto out = difference * rsqrt_var;
+
+    int dim_size = x_dim.size();
+    auto x_shape_tensor = shape<T>(x);
+    std::vector<Tensor> slice_shape_concat;
+
+    auto shape_1 = full<T>({1}, 1, x_shape_tensor.dtype());
+    auto shape_2 =
+        cast<T>(get_slice<T>(x_shape_tensor, 1), x_shape_tensor.dtype());
+    auto shape_3 = full<T>({dim_size - 2}, 1, x_shape_tensor.dtype());
+
+    slice_shape_concat.push_back(shape_1);
+    slice_shape_concat.push_back(shape_2);
+    slice_shape_concat.push_back(shape_3);
+    auto slice_shape_tensor = concat<T>(slice_shape_concat, 0);
+
+    Tensor scale_cast;
+    if (scale) {
+      scale_cast =
+          backend::reshape_with_tensor<T>(scale.get(), slice_shape_tensor);
+      if (need_cast) {
+        scale_cast = cast<T>(scale_cast, DataType::FLOAT32);
+      }
+      out = out * scale_cast;
+    }
+    Tensor bias_cast;
+    if (bias) {
+      bias_cast =
+          backend::reshape_with_tensor<T>(bias.get(), slice_shape_tensor);
+      if (need_cast) {
+        bias_cast = cast<T>(bias_cast, DataType::FLOAT32);
+      }
+      out = out + bias_cast;
+    }
+
+    std::vector<int64_t> res_shape(1, -1);
+    auto mean_out = reshape<T>(mean_, res_shape);
+    auto variance_out = reshape<T>(rsqrt_var, res_shape);
+
+    Tensor res;
+    if (need_cast) {
+      res = cast<T>(out, org_dtype);
+    } else {
+      res = out;
+    }
+
+    return std::make_tuple(res, mean_out, variance_out);
+  }
   auto org_dtype = x.dtype();
   Tensor x_cast = x;
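
For reference, a small NumPy sketch, not part of the commit, of the arithmetic the dynamic-shape branch above performs. Matching the code, the third output is rsqrt(var + epsilon) flattened to one dimension rather than the raw variance; the helper name instance_norm_reference is made up for illustration.

import numpy as np

def instance_norm_reference(x, scale=None, bias=None, epsilon=1e-5):
    # Reduce over the spatial axes 2..N-1 with keepdims, like mean_decomp with keepdim=true.
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    diff = x - mean
    var = (diff * diff).mean(axis=axes, keepdims=True)
    rsqrt_var = 1.0 / np.sqrt(var + epsilon)
    out = diff * rsqrt_var
    # scale/bias broadcast with shape [1, C, 1, ...], as built via slice_shape_tensor above.
    bshape = (1, x.shape[1]) + (1,) * (x.ndim - 2)
    if scale is not None:
        out = out * scale.reshape(bshape)
    if bias is not None:
        out = out + bias.reshape(bshape)
    return out, mean.reshape(-1), rsqrt_var.reshape(-1)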

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

Lines changed: 17 additions & 0 deletions
@@ -164,6 +164,10 @@ def layer_norm_net1(x):
     return paddle.nn.functional.layer_norm(x, x.shape[1:])
 
 
+def instance_norm_net(x):
+    return paddle.nn.functional.instance_norm(x)
+
+
 def flatten_net(x):
     return paddle.flatten(x, 1, 2)
 

@@ -488,6 +492,19 @@ def setUp(self):
         self.tol = 5e-6
 
 
+class TestPrimInstancenorm(TestPrimBase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [2, 32, 128]
+        self.dtype_x = "float32"
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.net = instance_norm_net
+        self.necessary_ops = "pd_op.instance_norm"
+        self.enable_cinn = False
+        self.tol = 5e-6
+
+
 class TestPrimGroupNorm1(TestPrimBase):
     def setUp(self):
         np.random.seed(2023)
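
To run just the new case locally, a hedged sketch; it assumes the test file's directory is on sys.path, and CI may invoke the suite differently.

import unittest
from test_prim_sub_graph_dynamic_shape import TestPrimInstancenorm

# Build and run a suite containing only the dynamic-shape instance_norm test.
suite = unittest.TestLoader().loadTestsFromTestCase(TestPrimInstancenorm)
unittest.TextTestRunner(verbosity=2).run(suite)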
