Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddle/fluid/operators/expand_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class ExpandV2Op : public framework::OperatorWithKernel {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
out_shape[i] = x_dims[i];
} else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
Expand Down Expand Up @@ -174,7 +177,7 @@ class ExpandV2GradOp : public framework::OperatorWithKernel {
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);

for (size_t i = 0; i < expand_shape.size(); ++i) {
if (expand_shape[i] == -1 || x_dim_vec[i] == -1) {
if (expand_shape[i] < 0 || x_dim_vec[i] == -1) {
continue;
} else {
if (ctx->IsRuntime()) {
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ def get_attr_expand_shape(list_expand_shape):
attrs_expand_shape = []
for idx, shape in enumerate(list_expand_shape):
if isinstance(shape, Variable):
attrs_expand_shape.append(-1)
attrs_expand_shape.append(-2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里最好也加一个注释,并对应c++哪个文件op代码

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c++端有对应的解释。这个地方也确实应该加一个注释;稍后提一个新pr添加。

else:
attrs_expand_shape.append(shape)
assert shape > 0 or shape == -1, (
Expand Down