Skip to content

Commit 4725667

Browse files
committed
test: add tests for py::scoped_critical_section
1 parent 032fef2 commit 4725667

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ set(PYBIND11_TEST_FILES
166166
test_potentially_slicing_weak_ptr
167167
test_python_multiple_inheritance
168168
test_pytypes
169+
test_scoped_critical_section
169170
test_sequences_and_iterators
170171
test_smart_ptr
171172
test_stl
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include <pybind11/critical_section.h>
2+
3+
#include "catch.hpp"
4+
#include "pybind11_tests.h"
5+
6+
#include <atomic>
7+
#include <thread>
8+
9+
#ifdef PYBIND11_CPP20
10+
# include <barrier>
11+
12+
// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874
13+
class BoolWrapper {
14+
public:
15+
explicit BoolWrapper(bool value) : value_{value} {}
16+
bool get() const { return value_.load(std::memory_order_acquire); }
17+
void set(bool value) { value_.store(value, std::memory_order_release); }
18+
19+
private:
20+
std::atomic<bool> value_;
21+
};
22+
23+
void test_scoped_critical_section(py::class_<BoolWrapper> &cls) {
24+
auto barrier = std::barrier(2);
25+
auto bool_wrapper = cls(false);
26+
27+
std::thread t1([&]() {
28+
py::scoped_critical_section lock{bool_wrapper};
29+
barrier.arrive_and_wait();
30+
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
31+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
32+
bw->set(true);
33+
});
34+
35+
std::thread t2([&]() {
36+
barrier.arrive_and_wait();
37+
py::scoped_critical_section lock{bool_wrapper};
38+
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
39+
REQUIRE(bw->get() == true);
40+
});
41+
42+
t1.join();
43+
t2.join();
44+
}
45+
46+
void test_scoped_critical_section2(py::class_<BoolWrapper> &cls) {
47+
auto barrier = std::barrier(3);
48+
auto bool_wrapper1 = cls(false);
49+
auto bool_wrapper2 = cls(false);
50+
51+
std::thread t1([&]() {
52+
py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2};
53+
barrier.arrive_and_wait();
54+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
55+
auto bw1 = bool_wrapper1.cast<std::shared_ptr<BoolWrapper>>();
56+
auto bw2 = bool_wrapper2.cast<std::shared_ptr<BoolWrapper>>();
57+
bw1->set(true);
58+
bw2->set(true);
59+
});
60+
61+
std::thread t2([&]() {
62+
barrier.arrive_and_wait();
63+
py::scoped_critical_section lock{bool_wrapper1};
64+
auto bw1 = bool_wrapper1.cast<std::shared_ptr<BoolWrapper>>();
65+
REQUIRE(bw1->get() == true);
66+
});
67+
68+
std::thread t3([&]() {
69+
barrier.arrive_and_wait();
70+
py::scoped_critical_section lock{bool_wrapper2};
71+
auto bw2 = bool_wrapper2.cast<std::shared_ptr<BoolWrapper>>();
72+
REQUIRE(bw2->get() == true);
73+
});
74+
75+
t1.join();
76+
t2.join();
77+
t3.join();
78+
}
79+
80+
void test_scoped_critical_section2_same_object_no_deadlock(py::class_<BoolWrapper> &cls) {
81+
auto barrier = std::barrier(2);
82+
auto bool_wrapper = cls(false);
83+
84+
std::thread t1([&]() {
85+
py::scoped_critical_section lock{bool_wrapper, bool_wrapper};
86+
barrier.arrive_and_wait();
87+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
88+
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
89+
bw->set(true);
90+
});
91+
92+
std::thread t2([&]() {
93+
barrier.arrive_and_wait();
94+
py::scoped_critical_section lock{bool_wrapper};
95+
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
96+
REQUIRE(bw->get() == true);
97+
});
98+
99+
t1.join();
100+
t2.join();
101+
}
102+
#endif
103+
104+
TEST_SUBMODULE(scoped_critical_section, m) {
105+
m.attr("defined_THREAD_SANITIZER") =
106+
#if defined(THREAD_SANITIZER)
107+
true;
108+
#else
109+
false;
110+
#endif
111+
m.attr("has_barrier") =
112+
#if defined(PYBIND11_CPP20)
113+
true;
114+
#else
115+
false;
116+
#endif
117+
118+
#ifdef PYBIND11_CPP20
119+
auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper")
120+
.def(py::init<bool>())
121+
.def("get", &BoolWrapper::get)
122+
.def("set", &BoolWrapper::set);
123+
124+
m.def("test_scoped_critical_section",
125+
[&]() -> void { test_scoped_critical_section(BoolWrapperClass); });
126+
m.def("test_scoped_critical_section2",
127+
[&]() -> void { test_scoped_critical_section2(BoolWrapperClass); });
128+
m.def("test_scoped_critical_section2_same_object_no_deadlock", [&]() -> void {
129+
test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperClass);
130+
});
131+
#endif
132+
}

tests/test_scoped_critical_section.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pybind11_tests import scoped_critical_section as m
6+
7+
8+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
9+
def test_scoped_critical_section() -> None:
10+
for _ in range(64):
11+
m.test_scoped_critical_section()
12+
13+
14+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
15+
def test_scoped_critical_section2() -> None:
16+
for _ in range(64):
17+
assert m.test_scoped_critical_section2()
18+
19+
20+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
21+
def test_scoped_critical_section2_same_object_no_deadlock() -> None:
22+
for _ in range(64):
23+
m.test_scoped_critical_section2_same_object_no_deadlock()

0 commit comments

Comments
 (0)