Skip to content

Commit bf081f4

Browse files
authored
[CINN]optimize concat grad vjp (#70990)
* optimize concat grad vjp * update
1 parent 4e8e11b commit bf081f4

File tree

1 file changed

+6
-5
lines changed
  • paddle/fluid/primitive/decomp_rule/decomp_vjp

1 file changed

+6
-5
lines changed

paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,14 +552,15 @@ void concat_grad(const std::vector<Tensor>& x,
552552

553553
int x_num = x.size();
554554
std::vector<Tensor> x_grad_tmp;
555-
bool has_dynamic = false;
555+
556+
int neg_num = 0;
556557
for (size_t i = 0; i < x.size(); i++) {
557-
if (has_dynamic_shape(x[i].shape())) {
558-
has_dynamic = true;
559-
break;
558+
if (x[i].dims()[axis_value] < 0) {
559+
neg_num++;
560560
}
561561
}
562-
if (has_dynamic) {
562+
563+
if (neg_num > 1) {
563564
std::vector<Tensor> sections;
564565
for (int i = 0; i < x_num; i++) {
565566
sections.push_back(get_slice<T>(shape64<T>(x[i]), int64_t(axis_value)));

0 commit comments

Comments
 (0)