Skip to content

Commit 3b0e253

Browse files
authored
[SOT][Faster Guard] add FasterStringifiedExpression (#69353)
1 parent ef23e8f commit 3b0e253

File tree

10 files changed

+167
-52
lines changed

10 files changed

+167
-52
lines changed

paddle/fluid/pybind/jit.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,14 @@ void BindGuard(pybind11::module *m) {
7171
.def(py::init<const py::function &>(), py::arg("guard_check_fn"));
7272
py::class_<GuardGroup, GuardBase, std::shared_ptr<GuardGroup>>(
7373
*m, "GuardGroup", R"DOC(GuardGroup Class.)DOC")
74-
.def(py::init<std::vector<std::shared_ptr<GuardBase>>>(),
74+
.def(py::init<const std::vector<std::shared_ptr<GuardBase>> &>(),
7575
py::arg("guards"));
7676
py::class_<TypeMatchGuard, GuardBase, std::shared_ptr<TypeMatchGuard>>(
7777
*m, "TypeMatchGuard", R"DOC(TypeMatchGuard Class.)DOC")
7878
.def(py::init<const py::type &>(), py::arg("py_type"));
7979
py::class_<LengthMatchGuard, GuardBase, std::shared_ptr<LengthMatchGuard>>(
8080
*m, "LengthMatchGuard", R"DOC(LengthMatchGuard Class.)DOC")
81-
.def(py::init<Py_ssize_t>(), py::arg("length"));
81+
.def(py::init<const Py_ssize_t &>(), py::arg("length"));
8282
py::class_<ValueMatchGuard, GuardBase, std::shared_ptr<ValueMatchGuard>>(
8383
*m, "ValueMatchGuard", R"DOC(ValueMatchGuard Class.)DOC")
8484
.def(py::init<const py::object &>(), py::arg("py_value"));
@@ -90,6 +90,9 @@ void BindGuard(pybind11::module *m) {
9090
py::class_<LayerMatchGuard, GuardBase, std::shared_ptr<LayerMatchGuard>>(
9191
*m, "LayerMatchGuard", R"DOC(LayerMatchGuard Class.)DOC")
9292
.def(py::init<const py::object &>(), py::arg("layer_obj"));
93+
py::class_<ShapeMatchGuard, GuardBase, std::shared_ptr<ShapeMatchGuard>>(
94+
*m, "ShapeMatchGuard", R"DOC(ShapeMatchGuard Class.)DOC")
95+
.def(py::init<const std::vector<py::object> &>(), py::arg("shape"));
9396

9497
m->def(
9598
"merge_guard",

paddle/fluid/pybind/sot/guards.cc

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/pybind/sot/guards.h"
16+
#include "paddle/phi/api/include/tensor.h"
1617

1718
#if SOT_IS_SUPPORTED
1819

@@ -25,8 +26,16 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
2526
}
2627
#endif
2728

29+
std::optional<paddle::Tensor> GetTensorFromPyObject(PyObject* obj) {
30+
if (!paddle::pybind::PyCheckTensor(obj)) {
31+
// TODO(zrr1999): PyCheckTensor only check if the object is a p_tensor_type.
32+
return std::nullopt;
33+
}
34+
return reinterpret_cast<paddle::pybind::TensorObject*>(obj)->tensor;
35+
}
36+
2837
bool LambdaGuard::check(PyObject* value) {
29-
PyObject* x = PyObject_CallOneArg(_guard_check_fn, value);
38+
PyObject* x = PyObject_CallOneArg(guard_check_fn_, value);
3039
if (x == nullptr) {
3140
PyErr_Clear();
3241
return false;
@@ -37,7 +46,7 @@ bool LambdaGuard::check(PyObject* value) {
3746
}
3847

3948
bool GuardGroup::check(PyObject* value) {
40-
for (auto& guard : _guards) {
49+
for (auto& guard : guards_) {
4150
if (!guard->check(value)) {
4251
return false;
4352
}
@@ -46,17 +55,17 @@ bool GuardGroup::check(PyObject* value) {
4655
}
4756

4857
bool TypeMatchGuard::check(PyObject* value) {
49-
return Py_TYPE(value) == _expected;
58+
return Py_TYPE(value) == expected_;
5059
}
5160

5261
bool ValueMatchGuard::check(PyObject* value) {
53-
if (value == _expected_value) {
62+
if (value == expected_value_) {
5463
return true;
5564
}
56-
if (Py_TYPE(value) != _expected_type) {
65+
if (Py_TYPE(value) != expected_type_) {
5766
return false;
5867
}
59-
int result = PyObject_RichCompareBool(value, _expected_value, Py_EQ);
68+
int result = PyObject_RichCompareBool(value, expected_value_, Py_EQ);
6069
// Check for exception
6170
if (result == -1) {
6271
PyErr_Clear();
@@ -66,25 +75,41 @@ bool ValueMatchGuard::check(PyObject* value) {
6675
}
6776

6877
bool LengthMatchGuard::check(PyObject* value) {
69-
return PySequence_Size(value) == _expected;
78+
return PySequence_Size(value) == expected_;
7079
}
7180

7281
bool DtypeMatchGuard::check(PyObject* value) {
73-
if (!paddle::pybind::PyCheckTensor(value)) {
74-
// TODO(zrr1999): PyCheckTensor only check if the object is a p_tensor_type.
82+
auto tensor = GetTensorFromPyObject(value);
83+
if (!tensor) {
7584
return false;
7685
}
77-
auto dtype =
78-
reinterpret_cast<paddle::pybind::TensorObject*>(value)->tensor.type();
79-
return phi::TransToProtoVarType(dtype) == _expected;
86+
auto dtype = tensor->type();
87+
return phi::TransToProtoVarType(dtype) == expected_;
88+
}
89+
90+
bool ShapeMatchGuard::check(PyObject* value) {
91+
auto tensor = GetTensorFromPyObject(value);
92+
if (!tensor) {
93+
return false;
94+
}
95+
auto shape = tensor->shape();
96+
if (shape.size() != expected_.size()) {
97+
return false;
98+
}
99+
for (size_t i = 0; i < shape.size(); ++i) {
100+
if (expected_[i] && shape[i] != *expected_[i]) {
101+
return false;
102+
}
103+
}
104+
return true;
80105
}
81106

82107
bool LayerMatchGuard::check(PyObject* value) {
83-
if (value != _layer_ptr) {
108+
if (value != layer_ptr_) {
84109
return false;
85110
}
86111
PyObject* training = PyObject_GetAttrString(value, "training");
87-
return (training == Py_True) == _training;
112+
return (training == Py_True) == training_;
88113
}
89114

90115
#endif

paddle/fluid/pybind/sot/guards.h

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616
#include <Python.h>
1717
#include "paddle/fluid/framework/data_type.h"
1818
#include "paddle/fluid/pybind/sot/macros.h"
19-
#include "paddle/phi/core/framework/heter_service.pb.h"
2019
#include "paddle/phi/core/utils/data_type.h"
2120
#include "paddle/utils/pybind.h"
2221
#include "pybind11/pybind11.h"
@@ -38,111 +37,131 @@ class GuardBase {
3837
class LambdaGuard : public GuardBase {
3938
public:
4039
explicit LambdaGuard(PyObject* guard_check_fn)
41-
: _guard_check_fn(guard_check_fn) {}
40+
: guard_check_fn_(guard_check_fn) {}
4241

4342
explicit LambdaGuard(const py::function& guard_check_fn)
44-
: _guard_check_fn(guard_check_fn.ptr()) {
45-
Py_INCREF(_guard_check_fn);
43+
: guard_check_fn_(guard_check_fn.ptr()) {
44+
Py_INCREF(guard_check_fn_);
4645
}
4746

48-
~LambdaGuard() { Py_DECREF(_guard_check_fn); }
47+
~LambdaGuard() { Py_DECREF(guard_check_fn_); }
4948

5049
bool check(PyObject* value);
5150

5251
private:
53-
PyObject* _guard_check_fn;
52+
PyObject* guard_check_fn_;
5453
};
5554

5655
class GuardGroup : public GuardBase {
5756
public:
58-
explicit GuardGroup(std::vector<std::shared_ptr<GuardBase>> guards) {
57+
explicit GuardGroup(const std::vector<std::shared_ptr<GuardBase>>& guards) {
5958
for (auto& guard : guards) {
6059
if (auto group = dynamic_cast<GuardGroup*>(guard.get())) {
61-
_guards.insert(
62-
_guards.end(), group->_guards.begin(), group->_guards.end());
60+
guards_.insert(
61+
guards_.end(), group->guards_.begin(), group->guards_.end());
6362
} else {
64-
_guards.push_back(std::move(guard));
63+
guards_.push_back(std::move(guard));
6564
}
6665
}
6766
}
6867
bool check(PyObject* value);
6968

7069
private:
71-
std::vector<std::shared_ptr<GuardBase>> _guards;
70+
std::vector<std::shared_ptr<GuardBase>> guards_;
7271
};
7372

7473
class TypeMatchGuard : public GuardBase {
7574
public:
7675
explicit TypeMatchGuard(PyObject* type_ptr)
77-
: _expected(reinterpret_cast<PyTypeObject*>(type_ptr)) {}
76+
: expected_(reinterpret_cast<PyTypeObject*>(type_ptr)) {}
7877

7978
explicit TypeMatchGuard(const py::type& py_type)
80-
: _expected(reinterpret_cast<PyTypeObject*>(py_type.ptr())) {}
79+
: expected_(reinterpret_cast<PyTypeObject*>(py_type.ptr())) {}
8180

8281
bool check(PyObject* value);
8382

8483
private:
85-
PyTypeObject* _expected;
84+
PyTypeObject* expected_;
8685
};
8786

8887
class ValueMatchGuard : public GuardBase {
8988
public:
9089
explicit ValueMatchGuard(PyObject* value_ptr)
91-
: _expected_value(value_ptr), _expected_type(value_ptr->ob_type) {}
90+
: expected_value_(value_ptr), expected_type_(value_ptr->ob_type) {}
9291

9392
explicit ValueMatchGuard(const py::object& py_value)
94-
: _expected_value(py_value.ptr()),
95-
_expected_type(Py_TYPE(py_value.ptr())) {
96-
Py_INCREF(_expected_value);
93+
: expected_value_(py_value.ptr()),
94+
expected_type_(Py_TYPE(py_value.ptr())) {
95+
Py_INCREF(expected_value_);
9796
}
9897

99-
~ValueMatchGuard() { Py_DECREF(_expected_value); }
98+
~ValueMatchGuard() { Py_DECREF(expected_value_); }
10099

101100
bool check(PyObject* value);
102101

103102
private:
104-
PyObject* _expected_value;
105-
PyTypeObject* _expected_type;
103+
PyObject* expected_value_;
104+
PyTypeObject* expected_type_;
106105
};
107106

108107
class LengthMatchGuard : public GuardBase {
109108
public:
110-
explicit LengthMatchGuard(Py_ssize_t length) : _expected(length) {}
109+
explicit LengthMatchGuard(const Py_ssize_t& length) : expected_(length) {}
111110

112111
bool check(PyObject* value);
113112

114113
private:
115-
Py_ssize_t _expected;
114+
Py_ssize_t expected_;
116115
};
117116

118117
class DtypeMatchGuard : public GuardBase {
119118
public:
120119
explicit DtypeMatchGuard(const paddle::framework::proto::VarType& dtype_ptr)
121-
: _expected(dtype_ptr.type()) {}
120+
: expected_(dtype_ptr.type()) {}
122121

123122
explicit DtypeMatchGuard(const phi::DataType& dtype_ptr)
124-
: _expected(phi::TransToProtoVarType(dtype_ptr)) {}
123+
: expected_(phi::TransToProtoVarType(dtype_ptr)) {}
125124

126125
bool check(PyObject* value);
127126

128127
private:
129-
int _expected;
128+
int expected_;
129+
};
130+
131+
class ShapeMatchGuard : public GuardBase {
132+
public:
133+
explicit ShapeMatchGuard(const std::vector<std::optional<int64_t>>& shape)
134+
: expected_(shape) {}
135+
136+
explicit ShapeMatchGuard(const std::vector<py::object>& shape) {
137+
expected_.resize(shape.size());
138+
for (size_t i = 0; i < shape.size(); ++i) {
139+
if (py::isinstance<py::int_>(shape[i]) && shape[i].cast<int64_t>() > 0) {
140+
expected_[i] = std::make_optional(shape[i].cast<int64_t>());
141+
}
142+
}
143+
}
144+
145+
bool check(PyObject* value);
146+
147+
private:
148+
std::vector<std::optional<int64_t>> expected_;
130149
};
131150

132151
class LayerMatchGuard : public GuardBase {
133152
public:
134-
explicit LayerMatchGuard(PyObject* layer_ptr) : _layer_ptr(layer_ptr) {
135-
_training = PyObject_GetAttrString(layer_ptr, "training") == Py_True;
153+
explicit LayerMatchGuard(PyObject* layer_ptr) : layer_ptr_(layer_ptr) {
154+
training_ = PyObject_GetAttrString(layer_ptr, "training") == Py_True;
136155
}
137156

138157
explicit LayerMatchGuard(const py::object& layer_obj)
139-
: _layer_ptr(layer_obj.ptr()), _training(layer_obj.attr("training")) {}
158+
: layer_ptr_(layer_obj.ptr()), training_(layer_obj.attr("training")) {}
140159

141160
bool check(PyObject* value);
142161

143162
private:
144-
PyObject* _layer_ptr;
145-
bool _training;
163+
PyObject* layer_ptr_;
164+
bool training_;
146165
};
147166

148167
#endif

python/paddle/jit/sot/opcode_translator/executor/guard.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,19 @@
2222
import paddle
2323

2424
from ...profiler import EventGuard
25-
from ...utils import current_symbol_registry, log, log_do
25+
from ...utils import (
26+
ENV_SOT_ENABLE_FASTER_GUARD,
27+
current_symbol_registry,
28+
log,
29+
log_do,
30+
)
2631

2732
Guard = Callable[[types.FrameType], bool]
2833

2934
if TYPE_CHECKING:
3035
from .variables import VariableBase
3136

37+
GuardBase = paddle.framework.core.GuardBase
3238
CheckGuardInputT = TypeVar("CheckGuardInputT", bound=VariableBase)
3339

3440
# NOTE(SigureMo): [How to write Stringified Guard?]
@@ -83,6 +89,33 @@ def __hash__(self):
8389
return hash(self.inlined_expr)
8490

8591

92+
class FasterStringifiedExpression(StringifiedExpression):
93+
def __init__(
94+
self,
95+
expr_template: str,
96+
faster_guard: GuardBase,
97+
sub_exprs: list[StringifiedExpression],
98+
free_vars: dict[str, Any],
99+
):
100+
self.faster_guard = faster_guard
101+
if ENV_SOT_ENABLE_FASTER_GUARD:
102+
original_expr_template = expr_template
103+
guard_cls_name = faster_guard.__class__.__name__
104+
guard_name = f"{guard_cls_name}_{id(faster_guard)}"
105+
expr_template = (
106+
guard_name + "(" + ", ".join(["{}"] * len(sub_exprs)) + ")"
107+
)
108+
free_vars = union_free_vars(
109+
free_vars, {guard_name: faster_guard.check}
110+
)
111+
log(
112+
3,
113+
f"[FasterGuard]: transform {original_expr_template} to {expr_template}\n",
114+
)
115+
116+
super().__init__(expr_template, sub_exprs, free_vars)
117+
118+
86119
def union_free_vars(*free_vars: dict[str, Any]):
87120
return {k: v for d in free_vars for k, v in d.items()}
88121

@@ -132,7 +165,7 @@ def support_weak_ref(obj):
132165

133166

134167
def check_guard(
135-
fn: Callable[[CheckGuardInputT], list[StringifiedExpression]]
168+
fn: Callable[[CheckGuardInputT], list[StringifiedExpression]],
136169
) -> Callable[[CheckGuardInputT], list[StringifiedExpression]]:
137170
def wrapper(self: CheckGuardInputT) -> list[StringifiedExpression]:
138171
assert (

0 commit comments

Comments
 (0)