Skip to content

Commit

Permalink
[SOT] Compare float in guard with a threshold (#70510)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Dec 27, 2024
1 parent 19c63a2 commit ccbc242
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 1 deletion.
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ void BindGuard(pybind11::module *m) {
py::class_<LengthMatchGuard, GuardBase, std::shared_ptr<LengthMatchGuard>>(
*m, "LengthMatchGuard", R"DOC(LengthMatchGuard Class.)DOC")
.def(py::init<const Py_ssize_t &>(), py::arg("length"));
py::class_<FloatCloseGuard, GuardBase, std::shared_ptr<FloatCloseGuard>>(
*m, "FloatCloseGuard", R"DOC(FloatCloseGuard Class.)DOC")
.def(py::init<const double, const double>(),
py::arg("value"),
py::arg("epsilon"));
py::class_<ValueMatchGuard, GuardBase, std::shared_ptr<ValueMatchGuard>>(
*m, "ValueMatchGuard", R"DOC(ValueMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("py_value"));
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ bool TypeMatchGuard::check(PyObject* value) {

bool IdMatchGuard::check(PyObject* value) { return value == expected_; }

bool FloatCloseGuard::check(PyObject* value) {
if (Py_TYPE(value) != &PyFloat_Type) {
return false;
}
double v = reinterpret_cast<PyFloatObject*>(value)->ob_fval;
return std::abs(v - expected_) < 1e-13;
}

bool ValueMatchGuard::check(PyObject* value) {
return PyObject_Equal(value, expected_value_);
}
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ class ValueMatchGuard : public GuardBase {
PyTypeObject* expected_type_;
};

class FloatCloseGuard : public GuardBase {
public:
explicit FloatCloseGuard(double value, double epsilon)
: expected_(value), epsilon_(epsilon) {}

bool check(PyObject* value);

private:
double expected_;
double epsilon_;
};

class LengthMatchGuard : public GuardBase {
public:
explicit LengthMatchGuard(const Py_ssize_t& length) : expected_(length) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,24 @@ def chr(self):
DummyTracker([self]),
)

@check_guard
def make_stringified_guard(self) -> list[StringifiedExpression]:
if self.get_py_type() is not float:
return super().make_stringified_guard()

frame_value_tracer = self.tracker.trace_value_from_frame()
epsilon = 1e-13
return [
FasterStringifiedExpression(
f"type({{0}}) is float and abs({self.get_py_value()!r} - {{0}}) < {epsilon}",
paddle.framework.core.FloatCloseGuard(
self.get_py_value(), epsilon
),
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars),
)
]

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if type(value) in ConstTypes:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def collect_subgraph_relation(self, inputs, outputs, partial_program_layer):
input_shape_infos,
output_shape_infos,
self.is_first_call,
self.graph_size(),
)

def __call__(self, *args, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/jit/sot/utils/info_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ def __init__(
input_shape_infos: list[SubGraphRelationInfo.ConcreteShapeInfo],
output_shape_infos: list[SubGraphRelationInfo.ConcreteShapeInfo],
is_first_call: bool,
graph_size: int,
):
super().__init__()
self.subgraph_name = subgraph_name
self.input_shape_infos = input_shape_infos
self.output_shape_infos = output_shape_infos
self.is_first_call = is_first_call
self.graph_size = graph_size

@classmethod
def summary(cls, history: list[Self]) -> str:
Expand Down Expand Up @@ -169,7 +171,7 @@ def to_tensor_node_name(
subgraph_id = f"subgraph_{i}"
dot.node(
subgraph_id,
f"Subgraph {i} ({info.subgraph_name})",
f"Subgraph {i} ({info.subgraph_name}, size={info.graph_size})",
shape='oval',
fillcolor='cyan' if info.is_first_call else None,
style='filled' if info.is_first_call else None,
Expand Down
8 changes: 8 additions & 0 deletions test/sot/test_faster_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def test_range_match_guard(self):
self.assertTrue(guard_range.check(range(1, 10, 2)))
self.assertFalse(guard_range.check(range(11)))

def test_float_close_guard(self):
expected = 0.018181818181818184
epsilon = 1e-13
guard_float = paddle.framework.core.FloatCloseGuard(expected, epsilon)
self.assertTrue(guard_float.check(0.018181818181818184))
self.assertTrue(guard_float.check(0.018181818181818177))
self.assertFalse(guard_float.check(0.018181818191818184))


if __name__ == "__main__":
unittest.main()

0 comments on commit ccbc242

Please sign in to comment.