Skip to content

Commit 84dd376

Browse files
committed
fix full selectew_rows bug
1 parent 50ae8dc commit 84dd376

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

paddle/pten/ops/compat/fill_constant_sig.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,45 +70,50 @@ KernelSignature FillConstantOpArgumentMapping(
7070
if (ctx.HasInput("ShapeTensor")) {
7171
if (ctx.HasInput("ValueTensor")) {
7272
return KernelSignature(
73-
"full_sr", {}, {"ShapeTensor", "ValueTensor"}, {"Out"});
73+
"full_sr", {}, {"ShapeTensor", "ValueTensor", "dtype"}, {"Out"});
7474
} else {
7575
const auto& str_value =
7676
paddle::any_cast<std::string>(ctx.Attr("str_value"));
7777
if (str_value.empty()) {
7878
return KernelSignature(
79-
"full_sr", {}, {"ShapeTensor", "value"}, {"Out"});
79+
"full_sr", {}, {"ShapeTensor", "value", "dtype"}, {"Out"});
8080
} else {
8181
return KernelSignature(
82-
"full_sr", {}, {"ShapeTensor", "str_value"}, {"Out"});
82+
"full_sr", {}, {"ShapeTensor", "str_value", "dtype"}, {"Out"});
8383
}
8484
}
8585
} else if (ctx.InputSize("ShapeTensorList") > 0) {
8686
if (ctx.HasInput("ValueTensor")) {
87-
return KernelSignature(
88-
"full_sr", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"});
87+
return KernelSignature("full_sr",
88+
{},
89+
{"ShapeTensorList", "ValueTensor", "dtype"},
90+
{"Out"});
8991
} else {
9092
const auto& str_value =
9193
paddle::any_cast<std::string>(ctx.Attr("str_value"));
9294
if (str_value.empty()) {
9395
return KernelSignature(
94-
"full_sr", {}, {"ShapeTensorList", "value"}, {"Out"});
96+
"full_sr", {}, {"ShapeTensorList", "value", "dtype"}, {"Out"});
9597
} else {
96-
return KernelSignature(
97-
"full_sr", {}, {"ShapeTensorList", "str_value"}, {"Out"});
98+
return KernelSignature("full_sr",
99+
{},
100+
{"ShapeTensorList", "str_value", "dtype"},
101+
{"Out"});
98102
}
99103
}
100104
} else {
101105
if (ctx.HasInput("ValueTensor")) {
102106
return KernelSignature(
103-
"full_sr", {}, {"shape", "ValueTensor"}, {"Out"});
107+
"full_sr", {}, {"shape", "ValueTensor", "dtype"}, {"Out"});
104108
} else {
105109
const auto& str_value =
106110
paddle::any_cast<std::string>(ctx.Attr("str_value"));
107111
if (str_value.empty()) {
108-
return KernelSignature("full_sr", {}, {"shape", "value"}, {"Out"});
112+
return KernelSignature(
113+
"full_sr", {}, {"shape", "value", "dtype"}, {"Out"});
109114
} else {
110115
return KernelSignature(
111-
"full_sr", {}, {"shape", "str_value"}, {"Out"});
116+
"full_sr", {}, {"shape", "str_value", "dtype"}, {"Out"});
112117
}
113118
}
114119
}

0 commit comments

Comments
 (0)