Skip to content

Commit 76e973c

Browse files
committed
Add query_intersections() method for efficient AABB pair detection
Implements a new query_intersections() method that efficiently finds all pairs of intersecting bounding boxes in the tree, addressing the feature request in issue #46.

Key features:
- Returns numpy array of shape (n_pairs, 2) with index pairs (i, j) where i < j
- Parallel processing using std::thread for improved performance
- Automatic double-precision refinement when exact coordinates are available
- No duplicate pairs or self-pairs
- Similar to scipy.spatial.cKDTree.query_pairs but for AABBs

This eliminates the need for manual post-processing of batch_query results using np.vectorize, np.repeat, and np.concatenate, which previously canceled out the performance gains from C++ parallelization.

Changes:
- cpp/prtree.h: Added query_intersections() method to PRTree class
- cpp/main.cc: Added Python bindings for all dimensions (2D, 3D, 4D)
- tests/test_PRTree.py: Added comprehensive tests including edge cases
- README.md: Updated documentation with usage examples

All tests pass (45 new tests + existing tests).

Resolves #46
1 parent 699c912 commit 76e973c

File tree

4 files changed

+373
-0
lines changed

4 files changed

+373
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ _python_prtree_ is a python/c++ implementation of the Priority R-Tree (see refer
99
- `query` and `batch_query`
1010
- `batch_query` is parallelized by `std::thread` and is much faster than the `query` method.
1111
- The `query` method has an optional keyword argument `return_obj`; if `return_obj=True`, a Python object is returned.
12+
- `query_intersections`
13+
- Returns all pairs of intersecting AABBs as a numpy array of shape (n_pairs, 2).
14+
- Optimized for performance with parallel processing and double-precision refinement.
15+
- Similar to `scipy.spatial.cKDTree.query_pairs` but for bounding boxes instead of points.
1216
- `rebuild`
1317
- It improves performance when many insert/delete operations are called since the last rebuild.
1418
- Note that if the size changes more than 1.5 times, the insert/erase method also performs `rebuild`.
@@ -77,6 +81,11 @@ print(prtree.query([0.5, 0.5]))
7781
# [1]
7882
print(prtree.query(0.5, 0.5)) # 1d-array
7983
# [1]
84+
85+
# Find all pairs of intersecting rectangles
86+
pairs = prtree.query_intersections()
87+
print(pairs)
88+
# [[1 3]] # rectangles with index 1 and 3 intersect
8089
```
8190

8291
```python

cpp/main.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ PYBIND11_MODULE(PRTree, m)
6262
)pbdoc")
6363
.def("size", &PRTree<T, B, 2>::size, R"pbdoc(
6464
get n
65+
)pbdoc")
66+
.def("query_intersections", &PRTree<T, B, 2>::query_intersections, R"pbdoc(
67+
Find all pairs of intersecting AABBs.
68+
Returns a numpy array of shape (n_pairs, 2) where each row contains
69+
a pair of indices (i, j) with i < j representing intersecting AABBs.
6570
)pbdoc");
6671

6772
py::class_<PRTree<T, B, 3>>(m, "_PRTree3D")
@@ -109,6 +114,11 @@ PYBIND11_MODULE(PRTree, m)
109114
)pbdoc")
110115
.def("size", &PRTree<T, B, 3>::size, R"pbdoc(
111116
get n
117+
)pbdoc")
118+
.def("query_intersections", &PRTree<T, B, 3>::query_intersections, R"pbdoc(
119+
Find all pairs of intersecting AABBs.
120+
Returns a numpy array of shape (n_pairs, 2) where each row contains
121+
a pair of indices (i, j) with i < j representing intersecting AABBs.
112122
)pbdoc");
113123

114124
py::class_<PRTree<T, B, 4>>(m, "_PRTree4D")
@@ -156,6 +166,11 @@ PYBIND11_MODULE(PRTree, m)
156166
)pbdoc")
157167
.def("size", &PRTree<T, B, 4>::size, R"pbdoc(
158168
get n
169+
)pbdoc")
170+
.def("query_intersections", &PRTree<T, B, 4>::query_intersections, R"pbdoc(
171+
Find all pairs of intersecting AABBs.
172+
Returns a numpy array of shape (n_pairs, 2) where each row contains
173+
a pair of indices (i, j) with i < j representing intersecting AABBs.
159174
)pbdoc");
160175

161176
#ifdef VERSION_INFO

cpp/prtree.h

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,4 +1626,176 @@ class PRTree
16261626
{
16271627
return static_cast<int64_t>(idx2bb.size());
16281628
}
1629+
1630+
/**
1631+
* Find all pairs of intersecting AABBs in the tree.
1632+
* Returns a numpy array of shape (n_pairs, 2) where each row contains
1633+
* a pair of indices (i, j) with i < j representing intersecting AABBs.
1634+
*
1635+
* This method is optimized for performance by:
1636+
* - Using parallel processing for queries
1637+
* - Avoiding duplicate pairs by enforcing i < j
1638+
* - Performing intersection checks in C++ to minimize Python overhead
1639+
* - Using double-precision refinement when exact coordinates are available
1640+
*
1641+
* @return py::array_t<T> Array of shape (n_pairs, 2) containing index pairs
1642+
*/
1643+
py::array_t<T> query_intersections()
1644+
{
1645+
// Collect all indices and bounding boxes
1646+
vec<T> indices;
1647+
vec<BB<D>> bboxes;
1648+
vec<std::array<double, 2 * D>> exact_coords;
1649+
1650+
if (unlikely(idx2bb.empty()))
1651+
{
1652+
// Return empty array of shape (0, 2)
1653+
vec<T> empty_data;
1654+
std::unique_ptr<vec<T>> data_ptr = std::make_unique<vec<T>>(std::move(empty_data));
1655+
auto capsule = py::capsule(data_ptr.get(), [](void *p)
1656+
{ std::unique_ptr<vec<T>>(reinterpret_cast<vec<T> *>(p)); });
1657+
data_ptr.release();
1658+
return py::array_t<T>({0, 2}, {2 * sizeof(T), sizeof(T)}, nullptr, capsule);
1659+
}
1660+
1661+
indices.reserve(idx2bb.size());
1662+
bboxes.reserve(idx2bb.size());
1663+
exact_coords.reserve(idx2bb.size());
1664+
1665+
for (const auto &pair : idx2bb)
1666+
{
1667+
indices.push_back(pair.first);
1668+
bboxes.push_back(pair.second);
1669+
1670+
// Get exact coordinates if available
1671+
auto it = idx2exact.find(pair.first);
1672+
if (it != idx2exact.end())
1673+
{
1674+
exact_coords.push_back(it->second);
1675+
}
1676+
else
1677+
{
1678+
// Create dummy exact coords from float32 BB (won't be used for refinement)
1679+
std::array<double, 2 * D> dummy;
1680+
for (int i = 0; i < D; ++i)
1681+
{
1682+
dummy[i] = static_cast<double>(pair.second.min(i));
1683+
dummy[i + D] = static_cast<double>(pair.second.max(i));
1684+
}
1685+
exact_coords.push_back(dummy);
1686+
}
1687+
}
1688+
1689+
const size_t n_items = indices.size();
1690+
1691+
// Use thread-local storage to collect pairs
1692+
const size_t n_threads = std::min(static_cast<size_t>(std::thread::hardware_concurrency()), n_items);
1693+
vec<vec<std::pair<T, T>>> thread_pairs(n_threads);
1694+
1695+
#ifdef MY_PARALLEL
1696+
vec<std::thread> threads;
1697+
threads.reserve(n_threads);
1698+
1699+
for (size_t t = 0; t < n_threads; ++t)
1700+
{
1701+
threads.emplace_back([&, t]()
1702+
{
1703+
vec<std::pair<T, T>> local_pairs;
1704+
1705+
for (size_t i = t; i < n_items; i += n_threads)
1706+
{
1707+
const T idx_i = indices[i];
1708+
const BB<D> &bb_i = bboxes[i];
1709+
1710+
// Find all intersections with this bounding box
1711+
auto candidates = find(bb_i);
1712+
1713+
// Refine candidates using exact coordinates if available
1714+
if (!idx2exact.empty())
1715+
{
1716+
candidates = refine_candidates(candidates, exact_coords[i]);
1717+
}
1718+
1719+
// Keep only pairs where idx_i < idx_j to avoid duplicates
1720+
for (const T &idx_j : candidates)
1721+
{
1722+
if (idx_i < idx_j)
1723+
{
1724+
local_pairs.emplace_back(idx_i, idx_j);
1725+
}
1726+
}
1727+
}
1728+
1729+
thread_pairs[t] = std::move(local_pairs);
1730+
});
1731+
}
1732+
1733+
for (auto &thread : threads)
1734+
{
1735+
thread.join();
1736+
}
1737+
#else
1738+
// Single-threaded version
1739+
vec<std::pair<T, T>> local_pairs;
1740+
1741+
for (size_t i = 0; i < n_items; ++i)
1742+
{
1743+
const T idx_i = indices[i];
1744+
const BB<D> &bb_i = bboxes[i];
1745+
1746+
// Find all intersections with this bounding box
1747+
auto candidates = find(bb_i);
1748+
1749+
// Refine candidates using exact coordinates if available
1750+
if (!idx2exact.empty())
1751+
{
1752+
candidates = refine_candidates(candidates, exact_coords[i]);
1753+
}
1754+
1755+
// Keep only pairs where idx_i < idx_j to avoid duplicates
1756+
for (const T &idx_j : candidates)
1757+
{
1758+
if (idx_i < idx_j)
1759+
{
1760+
local_pairs.emplace_back(idx_i, idx_j);
1761+
}
1762+
}
1763+
}
1764+
1765+
thread_pairs[0] = std::move(local_pairs);
1766+
#endif
1767+
1768+
// Merge results from all threads into a flat vector
1769+
vec<T> flat_pairs;
1770+
size_t total_pairs = 0;
1771+
for (const auto &pairs : thread_pairs)
1772+
{
1773+
total_pairs += pairs.size();
1774+
}
1775+
flat_pairs.reserve(total_pairs * 2);
1776+
1777+
for (const auto &pairs : thread_pairs)
1778+
{
1779+
for (const auto &pair : pairs)
1780+
{
1781+
flat_pairs.push_back(pair.first);
1782+
flat_pairs.push_back(pair.second);
1783+
}
1784+
}
1785+
1786+
// Create output numpy array using the same pattern as as_pyarray
1787+
auto data = flat_pairs.data();
1788+
std::unique_ptr<vec<T>> data_ptr = std::make_unique<vec<T>>(std::move(flat_pairs));
1789+
auto capsule = py::capsule(data_ptr.get(), [](void *p)
1790+
{ std::unique_ptr<vec<T>>(reinterpret_cast<vec<T> *>(p)); });
1791+
data_ptr.release();
1792+
1793+
// Return 2D array with shape (total_pairs, 2)
1794+
return py::array_t<T>(
1795+
{static_cast<py::ssize_t>(total_pairs), py::ssize_t(2)}, // shape
1796+
{2 * sizeof(T), sizeof(T)}, // strides (row-major)
1797+
data, // data pointer
1798+
capsule // capsule for cleanup
1799+
);
1800+
}
16291801
};

0 commit comments

Comments
 (0)