Skip to content

Commit 7c31c23

Browse files
authored
Keep minimum weight edge when merging parallel boundary edges in conversion from UserGraph to MatchingGraph/SearchGraph (#86)
* Ensure that parallel edge with minimum weight is kept when merging UserGraph boundary nodes results in parallel boundary edges. Fixes #81. * Update pm::UserGraph::to_search_graph to match parallel boundary edge behaviour of pm::UserGraph::to_matching_graph * Refactor by adding UserGraph::to_matching_or_search_graph_helper * Bump stim version in ci * Unpin ninja version in ci * Remove matrix.python-version in ci * Add setup-python action to build_wheels in ci * Update cibuildwheel to v2.16.5 * Specify python-version in pip_install ci * Add back macosx deployment target in build_wheels in ci
1 parent c30fcce commit 7c31c23

File tree

4 files changed

+78
-24
lines changed

4 files changed

+78
-24
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
CIBW_BUILD: "${{ matrix.os_dist.dist }}"
3535
CIBW_ARCHS_MACOS: "x86_64 universal2 arm64"
3636
CIBW_BEFORE_BUILD: pip install --upgrade ninja
37-
CIBW_TEST_REQUIRES: pytest stim~=1.10.dev1666411378
37+
CIBW_TEST_REQUIRES: pytest stim
3838
CIBW_TEST_COMMAND: pytest {project}/tests
3939
strategy:
4040
fail-fast: false
@@ -116,14 +116,16 @@ jobs:
116116
- uses: actions/checkout@v3
117117
with:
118118
submodules: true
119+
120+
- uses: actions/setup-python@v4
119121

120122
- name: Install g++
121123
if: runner.os == 'Linux'
122124
run: |
123125
sudo apt update
124126
sudo apt install gcc-10 g++-10
125127
126-
- uses: pypa/cibuildwheel@v2.16.4
128+
- uses: pypa/cibuildwheel@v2.16.5
127129

128130
- name: Verify clean directory
129131
run: git diff --exit-code
@@ -179,7 +181,7 @@ jobs:
179181
fail-fast: false
180182
matrix:
181183
platform: [windows-latest, macos-latest, ubuntu-latest]
182-
python-version: ["3.10"]
184+
python-version: ["3.11"]
183185

184186
runs-on: ${{ matrix.platform }}
185187

@@ -193,16 +195,16 @@ jobs:
193195
python-version: ${{ matrix.python-version }}
194196

195197
- name: Add requirements
196-
run: python -m pip install --upgrade cmake>=3.12 ninja==1.10.2.4 pytest flake8 pytest-cov
198+
run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov setuptools
197199

198200
- name: Build and install
199-
run: pip install --verbose -e .
201+
run: python -m pip install --verbose -e .
200202

201203
- name: Test without stim
202204
run: python -m pytest tests
203205

204206
- name: Add stim
205-
run: python -m pip install stim~=1.10.dev1666411378
207+
run: python -m pip install stim
206208

207209
- name: Test with stim using coverage
208210
run: python -m pytest tests --cov=./src/pymatching --cov-report term
@@ -218,8 +220,6 @@ jobs:
218220
submodules: true
219221

220222
- uses: actions/setup-python@v4
221-
with:
222-
python-version: ${{ matrix.python-version }}
223223

224224
- name: Install pandoc
225225
run: |
@@ -244,7 +244,7 @@ jobs:
244244
with:
245245
python-version: '3.10'
246246
- name: Add requirements
247-
run: python -m pip install --upgrade cmake>=3.12 ninja==1.10.2.4 pytest flake8 pytest-cov stim~=1.10.dev1666411378
247+
run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov stim
248248
- name: Build and install
249249
run: pip install --verbose -e .
250250
- name: Run tests and collect coverage

src/pymatching/sparse_blossom/driver/user_graph.cc

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -245,18 +245,17 @@ double pm::UserGraph::max_abs_weight() {
245245

246246
pm::MatchingGraph pm::UserGraph::to_matching_graph(pm::weight_int num_distinct_weights) {
247247
pm::MatchingGraph matching_graph(nodes.size(), _num_observables);
248-
double normalising_constant = iter_discretized_edges(
248+
249+
double normalising_constant = to_matching_or_search_graph_helper(
249250
num_distinct_weights,
250251
[&](size_t u, size_t v, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
251252
matching_graph.add_edge(u, v, weight, observables);
252253
},
253254
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
254-
// Only add the boundary edge if it already isn't present. Ideally parallel edges should already have been
255-
// merged, however we are implicitly merging all boundary nodes in this step, which could give rise to new
256-
// parallel edges.
257-
if (matching_graph.nodes[u].neighbors.empty() || matching_graph.nodes[u].neighbors[0])
258-
matching_graph.add_boundary_edge(u, weight, observables);
259-
});
255+
matching_graph.add_boundary_edge(u, weight, observables);
256+
}
257+
);
258+
260259
matching_graph.normalising_constant = normalising_constant;
261260
if (boundary_nodes.size() > 0) {
262261
matching_graph.is_user_graph_boundary_node.clear();
@@ -270,18 +269,16 @@ pm::MatchingGraph pm::UserGraph::to_matching_graph(pm::weight_int num_distinct_w
270269
pm::SearchGraph pm::UserGraph::to_search_graph(pm::weight_int num_distinct_weights) {
271270
/// Identical to to_matching_graph but for constructing a pm::SearchGraph
272271
pm::SearchGraph search_graph(nodes.size());
273-
iter_discretized_edges(
272+
273+
to_matching_or_search_graph_helper(
274274
num_distinct_weights,
275275
[&](size_t u, size_t v, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
276276
search_graph.add_edge(u, v, weight, observables);
277277
},
278278
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
279-
// Only add the boundary edge if it already isn't present. Ideally parallel edges should already have been
280-
// merged, however we are implicitly merging all boundary nodes in this step, which could give rise to new
281-
// parallel edges.
282-
if (search_graph.nodes[u].neighbors.empty() || search_graph.nodes[u].neighbors[0])
283-
search_graph.add_boundary_edge(u, weight, observables);
284-
});
279+
search_graph.add_boundary_edge(u, weight, observables);
280+
}
281+
);
285282
return search_graph;
286283
}
287284

src/pymatching/sparse_blossom/driver/user_graph.h

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ class UserGraph {
9999
pm::weight_int num_distinct_weights,
100100
const EdgeCallable& edge_func,
101101
const BoundaryEdgeCallable& boundary_edge_func);
102+
template <typename EdgeCallable, typename BoundaryEdgeCallable>
103+
double to_matching_or_search_graph_helper(
104+
pm::weight_int num_distinct_weights,
105+
const EdgeCallable& edge_func,
106+
const BoundaryEdgeCallable& boundary_edge_func);
102107
pm::MatchingGraph to_matching_graph(pm::weight_int num_distinct_weights);
103108
pm::SearchGraph to_search_graph(pm::weight_int num_distinct_weights);
104109
pm::Mwpm to_mwpm(pm::weight_int num_distinct_weights, bool ensure_search_graph_included);
@@ -120,7 +125,6 @@ inline double UserGraph::iter_discretized_edges(
120125
pm::weight_int num_distinct_weights,
121126
const EdgeCallable& edge_func,
122127
const BoundaryEdgeCallable& boundary_edge_func) {
123-
pm::MatchingGraph matching_graph(nodes.size(), _num_observables);
124128
double normalising_constant = get_edge_weight_normalising_constant(num_distinct_weights);
125129

126130
for (auto& e : edges) {
@@ -141,6 +145,38 @@ inline double UserGraph::iter_discretized_edges(
141145
return normalising_constant * 2;
142146
}
143147

148+
template <typename EdgeCallable, typename BoundaryEdgeCallable>
149+
inline double UserGraph::to_matching_or_search_graph_helper(
150+
pm::weight_int num_distinct_weights,
151+
const EdgeCallable& edge_func,
152+
const BoundaryEdgeCallable& boundary_edge_func) {
153+
154+
// Use vectors to store boundary edges initially before adding them to the graph, so
155+
// that parallel boundary edges with negative edge weights can be handled correctly
156+
std::vector<bool> has_boundary_edge(nodes.size(), false);
157+
std::vector<pm::signed_weight_int> boundary_edge_weights(nodes.size());
158+
std::vector<std::vector<size_t>> boundary_edge_observables(nodes.size());
159+
160+
double normalising_constant = iter_discretized_edges(
161+
num_distinct_weights,
162+
edge_func,
163+
[&](size_t u, pm::signed_weight_int weight, const std::vector<size_t>& observables) {
164+
// For parallel boundary edges, keep the boundary edge with the smaller weight
165+
if (!has_boundary_edge[u] || boundary_edge_weights[u] > weight){
166+
boundary_edge_weights[u] = weight;
167+
boundary_edge_observables[u] = observables;
168+
has_boundary_edge[u] = true;
169+
}
170+
});
171+
172+
// Now add boundary edges to the graph
173+
for (size_t i = 0; i < has_boundary_edge.size(); i++) {
174+
if (has_boundary_edge[i])
175+
boundary_edge_func(i, boundary_edge_weights[i], boundary_edge_observables[i]);
176+
}
177+
return normalising_constant;
178+
}
179+
144180
UserGraph detector_error_model_to_user_graph(const stim::DetectorErrorModel& detector_error_model);
145181

146182
} // namespace pm

tests/matching/decode_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,24 @@ def test_decode_to_edges():
276276
m.add_edge(i, i + 1)
277277
edges = m.decode_to_edges_array([0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0])
278278
assert np.array_equal(edges, np.array([[9, 8], [5, 6], [4, 3], [5, 4], [0, 1], [0, -1]], dtype=np.int64))
279+
280+
281+
def test_parallel_boundary_edges_decoding():
282+
m = Matching()
283+
m.set_boundary_nodes({0, 2})
284+
m.add_edge(0, 1, fault_ids=0, weight=3.5)
285+
m.add_edge(1, 2, fault_ids=1, weight=2.5)
286+
assert np.array_equal(m.decode([0, 1]), np.array([0, 1], dtype=np.uint8))
287+
m.add_boundary_edge(1, fault_ids=100, weight=100)
288+
# Test pm::SearchGraph
289+
assert np.array_equal(np.nonzero(m.decode([0, 1]))[0], np.array([1], dtype=int))
290+
291+
m = Matching()
292+
m.add_edge(0, 1, fault_ids=0, weight=-1)
293+
m.add_edge(0, 2, fault_ids=1, weight=3)
294+
m.add_boundary_edge(0, fault_ids=2, weight=-0.5)
295+
m.add_edge(0, 3, fault_ids=3, weight=-3)
296+
m.add_edge(0, 4, fault_ids=4, weight=-2)
297+
assert np.array_equal(m.decode([1, 0, 0, 0, 0]), np.array([0, 0, 1, 0, 0], dtype=np.uint8))
298+
m.set_boundary_nodes({1, 2, 3, 4})
299+
assert np.array_equal(m.decode([1, 0, 0, 0, 0]), np.array([0, 0, 0, 1, 0], dtype=np.uint8))

0 commit comments

Comments
 (0)