Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
{os: macos-latest, dist: cp312-macosx_x86_64},
{os: macos-latest, dist: cp313-macosx_x86_64},
# macosx arm64
{os: macos-latest, dist: cp38-macosx_arm64},
# {os: macos-latest, dist: cp38-macosx_arm64},
{os: macos-latest, dist: cp39-macosx_arm64},
{os: macos-latest, dist: cp310-macosx_arm64},
{os: macos-latest, dist: cp311-macosx_arm64},
Expand Down
38 changes: 24 additions & 14 deletions src/pymatching/sparse_blossom/driver/user_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
#include "pymatching/rand/rand_gen.h"
#include "pymatching/sparse_blossom/driver/implied_weights.h"

namespace {

// Probability that exactly one of two independent Bernoulli events fires
// (i.e. the XOR of the two events), given their marginal probabilities.
double bernoulli_xor(double p1, double p2) {
    double only_first_fires = p1 * (1 - p2);
    double only_second_fires = p2 * (1 - p1);
    return only_first_fires + only_second_fires;
}

}  // namespace


// Converts an error probability into a log-likelihood-ratio edge weight,
// log((1 - p) / p). Note std::log handles the extremes naturally:
// p == 0 yields +infinity and p == 1 yields -infinity.
double pm::to_weight_for_correlations(double probability) {
    const double odds_against = (1 - probability) / probability;
    return std::log(odds_against);
}

double pm::merge_weights(double a, double b) {
auto sgn = std::copysign(1, a) * std::copysign(1, b);
auto signed_min = sgn * std::min(std::abs(a), std::abs(b));
Expand Down Expand Up @@ -342,6 +355,15 @@ void pm::UserGraph::handle_dem_instruction(
}
}

// Adds a DEM error mechanism to the graph for correlation-aware decoding.
// Two detectors produce an internal edge; one detector produces a boundary
// edge. Other detector counts are rejected or ignored upstream, so they are
// deliberately no-ops here.
void pm::UserGraph::handle_dem_instruction_include_correlations(
    double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables) {
    const double weight = pm::to_weight_for_correlations(p);
    switch (detectors.size()) {
        case 2:
            add_or_merge_edge(detectors[0], detectors[1], observables, weight, p, INDEPENDENT);
            break;
        case 1:
            add_or_merge_boundary_edge(detectors[0], observables, weight, p, INDEPENDENT);
            break;
        default:
            break;
    }
}

void pm::UserGraph::get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes) {
auto& mwpm = get_mwpm_with_search_graph();
bool src_is_boundary = is_boundary_node(src);
Expand Down Expand Up @@ -444,18 +466,6 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin
}
}

namespace {

double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

double to_weight(double probability) {
return std::log((1 - probability) / probability);
}

} // namespace

void pm::add_decomposed_error_to_joint_probabilities(
DecomposedDemError& error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
Expand Down Expand Up @@ -490,7 +500,7 @@ pm::UserGraph pm::detector_error_model_to_user_graph(
pm::iter_dem_instructions_include_correlations(
detector_error_model,
[&](double p, const std::vector<size_t>& detectors, std::vector<size_t>& observables) {
user_graph.handle_dem_instruction(p, detectors, observables);
user_graph.handle_dem_instruction_include_correlations(p, detectors, observables);
},
joint_probabilites);

Expand Down Expand Up @@ -526,7 +536,7 @@ void pm::UserGraph::populate_implied_edge_weights(
// minimum of 0.5 as an implied probability for an edge to be reweighted.
double implied_probability_for_other_edge =
std::min(0.5, affected_edge_and_probability.second / marginal_probability);
double w = to_weight(implied_probability_for_other_edge);
double w = pm::to_weight_for_correlations(implied_probability_for_other_edge);
ImpliedWeightUnconverted implied{affected_edge.first, affected_edge.second, w};
edge.implied_weights_for_other_edges.push_back(implied);
}
Expand Down
54 changes: 26 additions & 28 deletions src/pymatching/sparse_blossom/driver/user_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class UserGraph {
Mwpm& get_mwpm();
Mwpm& get_mwpm_with_search_graph();
void handle_dem_instruction(double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
void handle_dem_instruction_include_correlations(
double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
void get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes);
void populate_implied_edge_weights(
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites);
Expand All @@ -131,6 +133,8 @@ class UserGraph {
bool _all_edges_have_error_probabilities;
};

double to_weight_for_correlations(double probability);

template <typename EdgeCallable, typename BoundaryEdgeCallable>
inline double UserGraph::iter_discretized_edges(
pm::weight_int num_distinct_weights,
Expand Down Expand Up @@ -263,7 +267,8 @@ template <typename Handler>
void iter_dem_instructions_include_correlations(
const stim::DetectorErrorModel& detector_error_model,
const Handler& handle_dem_error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites,
bool include_decomposed_error_components_in_edge_weights = true) {
detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) {
double p = instruction.arg_data[0];
pm::DecomposedDemError decomposed_err;
Expand Down Expand Up @@ -311,35 +316,23 @@ void iter_dem_instructions_include_correlations(
component->observable_indices.push_back(target.val());
} else if (target.is_separator()) {
instruction_contains_separator = true;
// If the previous error in the decomposition had 3 or more detectors, we throw an exception.
if (num_component_detectors > 2) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with a hyperedge component (3 or more detectors). "
"This is not supported.");
} else if (num_component_detectors == 0) {
// We cannot have num_component_detectors > 2 at this point, or we would have already thrown an
// exception
if (num_component_detectors == 0) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
"This is not supported.");
} else if (num_component_detectors > 0) {
// If the previous error in the decomposition had 1 or 2 detectors, we handle it
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
// The previous error in the decomposition must have 1 or 2 detectors
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
}

if (num_component_detectors > 2) {
// Undecomposed hyperedges are not supported
throw std::invalid_argument(
"Encountered an undecomposed error instruction with 3 or more detectors. "
"This is not supported when using `enable_correlations=True`. "
"Did you forget to set `decompose_errors=True` when "
"converting the stim circuit to a detector error model?");
} else if (num_component_detectors == 0) {
if (num_component_detectors == 0) {
if (instruction_contains_separator) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
Expand All @@ -348,12 +341,17 @@ void iter_dem_instructions_include_correlations(
// Ignore errors that are undetectable, provided they are not a component of a decomposed error
return;
}
}

} else if (num_component_detectors > 0) {
if (component->node2 == SIZE_MAX) {
handle_dem_error(p, {component->node1}, component->observable_indices);
} else {
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
// If include_decomposed_error_components_in_edge_weights is False, then only add the edge into
// the graph if it is not a component in a decomposed error with more than one component
if (include_decomposed_error_components_in_edge_weights || decomposed_err.components.size() == 1) {
for (pm::UserEdge& component : decomposed_err.components) {
if (component.node2 == SIZE_MAX) {
handle_dem_error(p, {component.node1}, component.observable_indices);
} else {
handle_dem_error(p, {component.node1, component.node2}, component.observable_indices);
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions src/pymatching/sparse_blossom/driver/user_graph.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) {
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 4);

std::vector<HandledError> expected = {
{0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}};
EXPECT_EQ(handler.handled_errors, expected);
Expand Down Expand Up @@ -485,14 +485,6 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {

graph.populate_implied_edge_weights(joint_probabilities);

auto to_weight = [](double p) {
if (p == 1.0)
return -std::numeric_limits<double>::infinity();
if (p == 0.0)
return std::numeric_limits<double>::infinity();
return std::log((1 - p) / p);
};

auto it_01 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
return edge.node1 == 0 && edge.node2 == 1;
});
Expand All @@ -503,7 +495,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
ASSERT_EQ(implied_01.node2, 3);

double p_01 = 0.1 / 0.26;
double w_01 = to_weight(p_01);
double w_01 = pm::to_weight_for_correlations(p_01);
ASSERT_EQ(implied_01.implied_weight, w_01);

auto it_23 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
Expand All @@ -515,7 +507,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
const auto& implied_23 = it_23->implied_weights_for_other_edges[0];
ASSERT_EQ(implied_23.node1, 0);
ASSERT_EQ(implied_23.node2, 1);
ASSERT_EQ(implied_23.implied_weight, 0);
ASSERT_NEAR(implied_23.implied_weight, 0.0, 0.00001);
}

TEST(UserGraph, ConvertImpliedWeights) {
Expand Down
20 changes: 20 additions & 0 deletions tests/matching/decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,26 @@ def test_decode_to_edges_with_correlations():
assert np.array_equal(edges, expected_edges)


def test_correlated_matching_handles_single_detector_components():
    # Regression test: building a correlated matching graph must not choke on
    # decomposed errors whose components each trigger only a single detector.
    stim = pytest.importorskip("stim")
    p = 0.1
    # Swap the depolarizing noise for a pure-Y Pauli channel so the resulting
    # DEM contains single-detector decomposed components.
    circuit = stim.Circuit.generated(
        code_task="surface_code:rotated_memory_x",
        distance=5,
        rounds=5,
        before_round_data_depolarization=p,
    )
    rewritten = str(circuit).replace(
        f"DEPOLARIZE1({p})", f"PAULI_CHANNEL_1(0, {p}, 0)"
    )
    dem = stim.Circuit(rewritten).detector_error_model(
        decompose_errors=True, approximate_disjoint_errors=True
    )
    matcher = Matching.from_detector_error_model(dem, enable_correlations=True)
    assert matcher.num_detectors > 0


def test_load_from_circuit_with_correlations():
stim = pytest.importorskip("stim")
circuit = stim.Circuit.generated(
Expand Down
Loading