Skip to content

Commit d0c847e

Browse files
authored
[Prim][Pir] Add FLAG_prim_backward_blacklist (#69156)
* add FLAG_prim_backward_blacklist
1 parent ba8fbba commit d0c847e

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

paddle/common/flags.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,14 @@ PHI_DEFINE_EXPORTED_string(
16891689
"",
16901690
"It controls the forward blacklist ops not to be decomposed.");
16911691

1692+
// PIR and prim related FLAG
1693+
// Example: If prim_backward_blacklist="relu_grad;mean_grad",
1694+
// it will block the decompsitions of `relu` and `mean` backward grads.
1695+
PHI_DEFINE_EXPORTED_string(
1696+
prim_backward_blacklist,
1697+
"",
1698+
"It controls the forward blacklist ops not to be decomposed.");
1699+
16921700
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
16931701
defined(PADDLE_WITH_XPU_BKCL)
16941702
/**

paddle/fluid/pybind/pybind.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ limitations under the License. */
233233
#endif
234234

235235
COMMON_DECLARE_bool(use_mkldnn);
236+
COMMON_DECLARE_string(prim_backward_blacklist);
236237

237238
// disable auto conversion to list in Python
238239
PYBIND11_MAKE_OPAQUE(phi::TensorArray);
@@ -847,6 +848,25 @@ static std::vector<std::vector<pir::Value>> GenerateBackwardBlockForPyLayerOp(
847848
return res;
848849
}
849850

851+
namespace {
852+
std::unordered_set<std::string> StringSplit(const std::string &str) {
853+
std::istringstream iss(str);
854+
std::unordered_set<std::string> tokens;
855+
std::string token;
856+
while (std::getline(iss, token, ';')) {
857+
size_t startpos = token.find_first_not_of(' ');
858+
size_t endpos = token.find_last_not_of(' ');
859+
if ((startpos != std::string::npos) && (endpos != std::string::npos)) {
860+
token = token.substr(startpos, endpos - startpos + 1);
861+
} else if (startpos != std::string::npos) {
862+
token = token.substr(startpos);
863+
}
864+
tokens.insert(token);
865+
}
866+
return tokens;
867+
}
868+
} // namespace
869+
850870
void BindVjp(pybind11::module *m) {
851871
m->def(
852872
"call_vjp",
@@ -878,6 +898,10 @@ void BindVjp(pybind11::module *m) {
878898
common::errors::InvalidArgument(
879899
"The vjp function is not registered in %s op ",
880900
fwd_op.name()));
901+
const std::unordered_set<std::string> backward_blacklist_ops =
902+
StringSplit(FLAGS_prim_backward_blacklist);
903+
paddle::prim::PrimCommonUtils::SetPrimBackwardBlacklist(
904+
backward_blacklist_ops);
881905
vjp_res = vjp_interface.Vjp(
882906
&fwd_op, inputs, outputs, out_grads, stop_gradients);
883907
}

test/prim/pir_prim/test_pir_prim_flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ def test_prim_backward_blacklist(self):
133133
self.train()
134134
core._set_prim_all_enabled(False)
135135

136+
def test_prim_backward_blacklist_flag(self):
137+
core._set_prim_all_enabled(True)
138+
paddle.set_flags(
139+
{"FLAGS_prim_backward_blacklist": "tanh_grad;exp_grad"}
140+
)
141+
self.train()
142+
core._set_prim_all_enabled(False)
143+
136144

137145
if __name__ == '__main__':
138146
unittest.main()

0 commit comments

Comments
 (0)