We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4e8e11b commit bf081f4Copy full SHA for bf081f4
paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -552,14 +552,15 @@ void concat_grad(const std::vector<Tensor>& x,
552
553
int x_num = x.size();
554
std::vector<Tensor> x_grad_tmp;
555
- bool has_dynamic = false;
+
556
+ int neg_num = 0;
557
for (size_t i = 0; i < x.size(); i++) {
- if (has_dynamic_shape(x[i].shape())) {
558
- has_dynamic = true;
559
- break;
+ if (x[i].dims()[axis_value] < 0) {
+ neg_num++;
560
}
561
562
- if (has_dynamic) {
563
+ if (neg_num > 1) {
564
std::vector<Tensor> sections;
565
for (int i = 0; i < x_num; i++) {
566
sections.push_back(get_slice<T>(shape64<T>(x[i]), int64_t(axis_value)));
0 commit comments