From 61149a97796c4374ff0e554b32542cd3c77a56e8 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Wed, 28 Feb 2024 18:52:48 -0800 Subject: [PATCH] Fix `circuit.explain_dem_errors` not supporting all gates (#704) - Fix a nasty miscomputed allocation size om `stim::MonotonicBuffer` - Add `stim::inplace_xor_sort` C++ helper method - Add support for `MXX`, `MYY`, `MZZ`, `HERALDED_ERASE`, `HERALDED_PAULI_CHANNEL_1`, `MPAD` to `stim::ErrorMatcher` - Add a unit test verifying `stim::ErrorMatcher` supports all gates Fixes https://github.com/quantumlib/Stim/issues/697 --- src/stim/mem/monotonic_buffer.h | 2 +- src/stim/mem/monotonic_buffer.test.cc | 14 +- src/stim/mem/sparse_xor_vec.h | 17 ++ src/stim/mem/sparse_xor_vec.test.cc | 19 +++ src/stim/simulators/error_analyzer.cc | 2 +- src/stim/simulators/error_matcher.cc | 189 ++++++++++++++++------ src/stim/simulators/error_matcher.h | 4 + src/stim/simulators/error_matcher.test.cc | 140 ++++++++++++++++ src/stim/simulators/matched_error.cc | 2 + 9 files changed, 341 insertions(+), 48 deletions(-) diff --git a/src/stim/mem/monotonic_buffer.h b/src/stim/mem/monotonic_buffer.h index 4f8f240eb..8b6998dcf 100644 --- a/src/stim/mem/monotonic_buffer.h +++ b/src/stim/mem/monotonic_buffer.h @@ -155,7 +155,7 @@ struct MonotonicBuffer { return; } - size_t alloc_count = std::max(min_required, cur.size() << 1); + size_t alloc_count = std::max(min_required + tail.size(), cur.size() << 1); if (cur.ptr_start != nullptr) { old_areas.push_back(cur); } diff --git a/src/stim/mem/monotonic_buffer.test.cc b/src/stim/mem/monotonic_buffer.test.cc index 2c2af9b1b..6cb0b0fc9 100644 --- a/src/stim/mem/monotonic_buffer.test.cc +++ b/src/stim/mem/monotonic_buffer.test.cc @@ -39,7 +39,7 @@ TEST(pointer_range, equality) { ASSERT_NE(r1, r2); } -TEST(monotonic_buffer, x) { +TEST(monotonic_buffer, append_tail) { MonotonicBuffer buf; for (size_t k = 0; k < 100; k++) { buf.append_tail(k); @@ -51,3 +51,15 @@ TEST(monotonic_buffer, x) { ASSERT_EQ(rng[k], k); } } + +TEST(monotonic_buffer, ensure_available) { + MonotonicBuffer buf; + buf.append_tail(std::vector{1, 2, 3, 4}); + buf.append_tail(std::vector{5, 6}); + buf.append_tail(std::vector{7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + SpanRef rng = buf.commit_tail(); + std::vector expected{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + SpanRef v = expected; + ASSERT_EQ(rng, v); +} diff --git a/src/stim/mem/sparse_xor_vec.h b/src/stim/mem/sparse_xor_vec.h index 964ac5d73..7bb4b5390 100644 --- a/src/stim/mem/sparse_xor_vec.h +++ b/src/stim/mem/sparse_xor_vec.h @@ -60,6 +60,23 @@ inline T *xor_merge_sort(SpanRef sorted_in1, SpanRef sorted_in return out; } +template +inline SpanRef inplace_xor_sort(SpanRef items) { + std::sort(items.begin(), items.end()); + size_t new_size = 0; + for (size_t k = 0; k < items.size(); k++) { + if (new_size > 0 && items[k] == items[new_size - 1]) { + new_size--; + } else { + if (k != new_size) { + std::swap(items[new_size], items[k]); + } + new_size++; + } + } + return items.sub(0, new_size); +} + template bool is_subset_of_sorted(SpanRef subset, SpanRef superset) { const T *p_sub = subset.ptr_start; diff --git a/src/stim/mem/sparse_xor_vec.test.cc b/src/stim/mem/sparse_xor_vec.test.cc index 29174602a..042f73e12 100644 --- a/src/stim/mem/sparse_xor_vec.test.cc +++ b/src/stim/mem/sparse_xor_vec.test.cc @@ -171,3 +171,22 @@ TEST(sparse_xor_vec, contains) { ASSERT_FALSE((SparseXorVec{{}}).contains(0)); ASSERT_FALSE((SparseXorVec{{1}}).contains(0)); } + +TEST(sparse_xor_vec, inplace_xor_sort) { + auto f = [](std::vector v) -> std::vector { + SpanRef s = v; + auto r = inplace_xor_sort(s); + v.resize(r.size()); + return v; + }; + ASSERT_EQ(f({}), (std::vector({}))); + ASSERT_EQ(f({5}), (std::vector({5}))); + ASSERT_EQ(f({5, 5}), (std::vector({}))); + ASSERT_EQ(f({5, 5, 5}), (std::vector({5}))); + ASSERT_EQ(f({5, 5, 5, 5}), (std::vector({}))); + ASSERT_EQ(f({5, 4, 5, 5}), (std::vector({4, 5}))); + ASSERT_EQ(f({4, 5, 5, 5}), (std::vector({4, 5}))); + ASSERT_EQ(f({5, 5, 5, 4}), (std::vector({4, 5}))); + ASSERT_EQ(f({4, 5, 5, 4}), (std::vector({}))); + ASSERT_EQ(f({3, 5, 5, 4}), (std::vector({3, 4}))); +} diff --git a/src/stim/simulators/error_analyzer.cc b/src/stim/simulators/error_analyzer.cc index 9f8e5245c..085074c7b 100644 --- a/src/stim/simulators/error_analyzer.cc +++ b/src/stim/simulators/error_analyzer.cc @@ -837,7 +837,7 @@ void ErrorAnalyzer::undo_DEPOLARIZE2(const CircuitInstruction &dat) { void ErrorAnalyzer::undo_ELSE_CORRELATED_ERROR(const CircuitInstruction &dat) { if (accumulate_errors) { - throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR" + dat.str()); + throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR: " + dat.str()); } } diff --git a/src/stim/simulators/error_matcher.cc b/src/stim/simulators/error_matcher.cc index 8d7f3157b..aa0f36293 100644 --- a/src/stim/simulators/error_matcher.cc +++ b/src/stim/simulators/error_matcher.cc @@ -59,16 +59,7 @@ ErrorMatcher::ErrorMatcher( } } -void ErrorMatcher::err_atom(const CircuitInstruction &effect) { - assert(error_analyzer.error_class_probabilities.empty()); - error_analyzer.undo_gate(effect); - if (error_analyzer.error_class_probabilities.empty()) { - /// Maybe there were no detectors or observables nearby? Or the noise probability was zero? - return; - } - - assert(error_analyzer.error_class_probabilities.size() == 1); - SpanRef dem_error_terms = error_analyzer.error_class_probabilities.begin()->first; +void ErrorMatcher::add_dem_error_terms(SpanRef dem_error_terms) { auto entry = output_map.find(dem_error_terms); if (!dem_error_terms.empty() && (allow_adding_new_dem_errors_to_output_map || entry != output_map.end())) { // We have a desired match! Record it. @@ -88,6 +79,19 @@ void ErrorMatcher::err_atom(const CircuitInstruction &effect) { out[0] = std::move(new_loc); } } +} + +void ErrorMatcher::err_atom(const CircuitInstruction &effect) { + assert(error_analyzer.error_class_probabilities.empty()); + error_analyzer.undo_gate(effect); + if (error_analyzer.error_class_probabilities.empty()) { + /// Maybe there were no detectors or observables nearby? Or the noise probability was zero? + return; + } + + assert(error_analyzer.error_class_probabilities.size() == 1); + SpanRef dem_error_terms = error_analyzer.error_class_probabilities.begin()->first; + add_dem_error_terms(dem_error_terms); // Restore the pristine state. error_analyzer.mono_buf.clear(); @@ -128,6 +132,58 @@ void ErrorMatcher::err_xyz(const CircuitInstruction &op, uint32_t target_flags) } } +void ErrorMatcher::err_heralded_pauli_channel_1(const CircuitInstruction &op) { + assert(op.args.size() == 4); + for (size_t k = op.targets.size(); k--;) { + auto q = op.targets[k].qubit_value(); + cur_loc.instruction_targets.target_range_start = k; + cur_loc.instruction_targets.target_range_end = k + 1; + + cur_loc.flipped_measurement.measurement_record_index = error_analyzer.tracker.num_measurements_in_past - 1; + SpanRef herald_symptoms = error_analyzer.tracker.rec_bits[error_analyzer.tracker.num_measurements_in_past - 1].range(); + SpanRef x_symptoms = error_analyzer.tracker.zs[q].range(); + SpanRef z_symptoms = error_analyzer.tracker.xs[q].range(); + if (op.args[0] != 0) { + add_dem_error_terms(herald_symptoms); + } + if (op.args[1] != 0) { + error_analyzer.mono_buf.append_tail(herald_symptoms); + error_analyzer.mono_buf.append_tail(x_symptoms); + error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail); + resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT, cur_loc.flipped_pauli_product); + add_dem_error_terms(error_analyzer.mono_buf.tail); + cur_loc.flipped_pauli_product.clear(); + error_analyzer.mono_buf.discard_tail(); + } + if (op.args[2] != 0) { + error_analyzer.mono_buf.append_tail(herald_symptoms); + error_analyzer.mono_buf.append_tail(x_symptoms); + error_analyzer.mono_buf.append_tail(z_symptoms); + error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail); + resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product); + add_dem_error_terms(error_analyzer.mono_buf.tail); + cur_loc.flipped_pauli_product.clear(); + error_analyzer.mono_buf.discard_tail(); + } + if (op.args[3] != 0) { + error_analyzer.mono_buf.append_tail(herald_symptoms); + error_analyzer.mono_buf.append_tail(z_symptoms); + error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail); + resolve_paulis_into(&op.targets[k], TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product); + add_dem_error_terms(error_analyzer.mono_buf.tail); + cur_loc.flipped_pauli_product.clear(); + error_analyzer.mono_buf.discard_tail(); + } + cur_loc.flipped_measurement.measurement_record_index = UINT64_MAX; + + assert(error_analyzer.error_class_probabilities.empty()); + error_analyzer.tracker.undo_gate(op); + error_analyzer.mono_buf.clear(); + error_analyzer.error_class_probabilities.clear(); + error_analyzer.flushed_reversed_model.clear(); + } +} + void ErrorMatcher::err_pauli_channel_1(const CircuitInstruction &op) { const auto &a = op.args; const auto &t = op.targets; @@ -187,12 +243,17 @@ void ErrorMatcher::err_m(const CircuitInstruction &op, uint32_t obs_mask) { const auto &t = op.targets; const auto &a = op.args; + bool q2 = GATE_DATA[op.gate_type].flags & GATE_TARGETS_PAIRS; size_t end = t.size(); while (end > 0) { size_t start = end - 1; while (start > 0 && t[start - 1].is_combiner()) { start -= std::min(start, size_t{2}); } + if (q2) { + start--; + } + SpanRef slice{t.begin() + start, t.begin() + end}; @@ -227,48 +288,86 @@ void ErrorMatcher::rev_process_instruction(const CircuitInstruction &op) { entry->second.push_back(d); } } + return; } else if (op.gate_type == GateType::SHIFT_COORDS) { error_analyzer.undo_SHIFT_COORDS(op); for (size_t k = 0; k < op.args.size(); k++) { cur_coord_offset[k] -= op.args[k]; } + return; } else if (!(flags & (GATE_IS_NOISY | GATE_PRODUCES_RESULTS))) { error_analyzer.undo_gate(op); - } else if (op.gate_type == GateType::E || op.gate_type == GateType::ELSE_CORRELATED_ERROR) { - cur_loc.instruction_targets.target_range_start = 0; - cur_loc.instruction_targets.target_range_end = op.targets.size(); - resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product); - err_atom(op); - cur_loc.flipped_pauli_product.clear(); - } else if (op.gate_type == GateType::X_ERROR) { - err_xyz(op, TARGET_PAULI_X_BIT); - } else if (op.gate_type == GateType::Y_ERROR) { - err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT); - } else if (op.gate_type == GateType::Z_ERROR) { - err_xyz(op, TARGET_PAULI_Z_BIT); - } else if (op.gate_type == GateType::PAULI_CHANNEL_1) { - err_pauli_channel_1(op); - } else if (op.gate_type == GateType::DEPOLARIZE1) { - float p = op.args[0]; - std::array spread{p, p, p}; - err_pauli_channel_1({op.gate_type, spread, op.targets}); - } else if (op.gate_type == GateType::PAULI_CHANNEL_2) { - err_pauli_channel_2(op); - } else if (op.gate_type == GateType::DEPOLARIZE2) { - float p = op.args[0]; - std::array spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p}; - err_pauli_channel_2({op.gate_type, spread, op.targets}); - } else if (op.gate_type == GateType::MPP) { - err_m(op, 0); - } else if (op.gate_type == GateType::MX || op.gate_type == GateType::MRX) { - err_m(op, TARGET_PAULI_X_BIT); - } else if (op.gate_type == GateType::MY || op.gate_type == GateType::MRY) { - err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT); - } else if (op.gate_type == GateType::M || op.gate_type == GateType::MR) { - err_m(op, TARGET_PAULI_Z_BIT); - } else { - throw std::invalid_argument( - "Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name)); + return; + } + switch (op.gate_type) { + case GateType::MPAD: + error_analyzer.undo_gate(op); + break; + case GateType::E: + case GateType::ELSE_CORRELATED_ERROR: { + cur_loc.instruction_targets.target_range_start = 0; + cur_loc.instruction_targets.target_range_end = op.targets.size(); + resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product); + CircuitInstruction op2 = op; + op2.gate_type = GateType::E; + err_atom(op2); + cur_loc.flipped_pauli_product.clear(); + break; + } case GateType::X_ERROR: + err_xyz(op, TARGET_PAULI_X_BIT); + break; + case GateType::Y_ERROR: + err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT); + break; + case GateType::Z_ERROR: + err_xyz(op, TARGET_PAULI_Z_BIT); + break; + case GateType::PAULI_CHANNEL_1: + err_pauli_channel_1(op); + break; + case GateType::HERALDED_PAULI_CHANNEL_1: + err_heralded_pauli_channel_1(op); + break; + case GateType::HERALDED_ERASE: { + float p = op.args[0] / 4; + std::array spread{p, p, p, p}; + err_heralded_pauli_channel_1({op.gate_type, spread, op.targets}); + break; + } case GateType::DEPOLARIZE1: { + float p = op.args[0]; + std::array spread{p, p, p}; + err_pauli_channel_1({op.gate_type, spread, op.targets}); + break; + } case GateType::PAULI_CHANNEL_2: + err_pauli_channel_2(op); + break; + case GateType::DEPOLARIZE2: { + float p = op.args[0]; + std::array spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p}; + err_pauli_channel_2({op.gate_type, spread, op.targets}); + break; + } + case GateType::MPP: + err_m(op, 0); + break; + case GateType::MX: + case GateType::MRX: + case GateType::MXX: + err_m(op, TARGET_PAULI_X_BIT); + break; + case GateType::MY: + case GateType::MRY: + case GateType::MYY: + err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT); + break; + case GateType::M: + case GateType::MR: + case GateType::MZZ: + err_m(op, TARGET_PAULI_Z_BIT); + break; + default: + throw std::invalid_argument( + "Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name)); } } diff --git a/src/stim/simulators/error_matcher.h b/src/stim/simulators/error_matcher.h index ae158b02f..9831e0a4c 100644 --- a/src/stim/simulators/error_matcher.h +++ b/src/stim/simulators/error_matcher.h @@ -78,6 +78,8 @@ struct ErrorMatcher { void err_atom(const CircuitInstruction &effect); /// Processes operations with X, Y, Z errors on each target. void err_pauli_channel_1(const CircuitInstruction &op); + /// Processes operations with M, X, Y, Z errors on each target. + void err_heralded_pauli_channel_1(const CircuitInstruction &op); /// Processes operations with 15 two-qubit Pauli product errors on each target pair. void err_pauli_channel_2(const CircuitInstruction &op); /// Processes measurement operations. @@ -88,6 +90,8 @@ struct ErrorMatcher { void rev_process_instruction(const CircuitInstruction &op); /// Processes entire circuits. void rev_process_circuit(uint64_t reps, const Circuit &block); + + void add_dem_error_terms(SpanRef dem_error_terms); }; } // namespace stim diff --git a/src/stim/simulators/error_matcher.test.cc b/src/stim/simulators/error_matcher.test.cc index 90227d3e2..1282c4dc1 100644 --- a/src/stim/simulators/error_matcher.test.cc +++ b/src/stim/simulators/error_matcher.test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" +#include "stim/circuit/circuit.test.h" #include "stim/gen/gen_rep_code.h" #include "stim/gen/gen_surface_code.h" @@ -153,6 +154,126 @@ TEST(ErrorMatcher, MPP_ERROR) { })RESULT"); } +TEST(ErrorMatcher, MXX_ERROR) { + auto actual = ErrorMatcher::explain_errors_from_circuit( + Circuit(R"CIRCUIT( + QUBIT_COORDS(5, 6) 0 + RX 0 + CX 0 1 + MXX(0.125) 0 1 + DETECTOR(2, 3) rec[-1] + )CIRCUIT"), + nullptr, + false); + ASSERT_EQ(actual.size(), 1); + ASSERT_EQ(actual[0].str(), R"RESULT(ExplainedError { + dem_error_terms: D0[coords 2,3] + CircuitErrorLocation { + flipped_measurement.measurement_record_index: 0 + flipped_measurement.measured_observable: X0[coords 5,6]*X1 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (MXX) in the circuit + at targets #1 to #2 of the instruction + resolving to MXX(0.125) 0[coords 5,6] 1 + } +})RESULT"); +} + +TEST(ErrorMatcher, ELSE_CORRELATED_ERROR) { + auto actual = ErrorMatcher::explain_errors_from_circuit( + Circuit(R"CIRCUIT( + R 0 1 + H 1 + CORRELATED_ERROR(0.25) X0 + ELSE_CORRELATED_ERROR(0.125) Z1 + H 1 + M 0 1 + DETECTOR rec[-1] + )CIRCUIT"), + nullptr, + false); + ASSERT_EQ(actual.size(), 1); + ASSERT_EQ(actual[0].str(), R"RESULT(ExplainedError { + dem_error_terms: D0 + CircuitErrorLocation { + flipped_pauli_product: Z1 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (ELSE_CORRELATED_ERROR) in the circuit + at target #1 of the instruction + resolving to ELSE_CORRELATED_ERROR(0.125) Z1 + } +})RESULT"); +} + +TEST(ErrorMatcher, HERALDED_ERASE) { + auto actual = ErrorMatcher::explain_errors_from_circuit( + Circuit(R"CIRCUIT( + MXX 0 1 + MYY 0 1 + MZZ 0 1 + HERALDED_ERASE(0.125) 0 + MXX 0 1 + MYY 0 1 + MZZ 0 1 + DETECTOR rec[-1] rec[-5] + DETECTOR rec[-2] rec[-6] + DETECTOR rec[-3] rec[-7] + DETECTOR rec[-4] + )CIRCUIT"), + nullptr, + false); + ASSERT_EQ(actual.size(), 4); + ASSERT_EQ(actual[0].str(), R"RESULT(ExplainedError { + dem_error_terms: D0 D1 D3 + CircuitErrorLocation { + flipped_pauli_product: X0 + flipped_measurement.measurement_record_index: 3 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (HERALDED_ERASE) in the circuit + at target #1 of the instruction + resolving to HERALDED_ERASE(0.125) 0 + } +})RESULT"); + ASSERT_EQ(actual[1].str(), R"RESULT(ExplainedError { + dem_error_terms: D0 D2 D3 + CircuitErrorLocation { + flipped_pauli_product: Y0 + flipped_measurement.measurement_record_index: 3 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (HERALDED_ERASE) in the circuit + at target #1 of the instruction + resolving to HERALDED_ERASE(0.125) 0 + } +})RESULT"); + ASSERT_EQ(actual[2].str(), R"RESULT(ExplainedError { + dem_error_terms: D1 D2 D3 + CircuitErrorLocation { + flipped_pauli_product: Z0 + flipped_measurement.measurement_record_index: 3 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (HERALDED_ERASE) in the circuit + at target #1 of the instruction + resolving to HERALDED_ERASE(0.125) 0 + } +})RESULT"); + ASSERT_EQ(actual[3].str(), R"RESULT(ExplainedError { + dem_error_terms: D3 + CircuitErrorLocation { + flipped_measurement.measurement_record_index: 3 + Circuit location stack trace: + (after 0 TICKs) + at instruction #4 (HERALDED_ERASE) in the circuit + at target #1 of the instruction + resolving to HERALDED_ERASE(0.125) 0 + } +})RESULT"); +} + TEST(ErrorMatcher, repetition_code_data_depolarization) { CircuitGenParameters params(2, 3, "memory"); params.before_round_data_depolarization = 0.001; @@ -403,3 +524,22 @@ ExplainedError { } )RESULT"); } + +TEST(ErrorMatcher, runs_on_all_gates_circuit) { + DetectorErrorModel filter(R"MODEL( + error(1) D0 +)MODEL"); + + auto circuit = generate_test_circuit_with_all_operations(); + auto actual = ErrorMatcher::explain_errors_from_circuit(circuit, &filter, false); + std::stringstream ss; + for (const auto &match : actual) { + ss << "\n" << match << "\n"; + } + ASSERT_EQ(ss.str(), R"RESULT( +ExplainedError { + dem_error_terms: D0[coords 2,4,6] + [no single circuit error had these exact symptoms] +} +)RESULT"); +} diff --git a/src/stim/simulators/matched_error.cc b/src/stim/simulators/matched_error.cc index 3e1ca684d..8ef03e91f 100644 --- a/src/stim/simulators/matched_error.cc +++ b/src/stim/simulators/matched_error.cc @@ -42,6 +42,8 @@ void print_circuit_error_loc_indent(std::ostream &out, const CircuitErrorLocatio out << indent << " flipped_measurement.measurement_record_index: " << e.flipped_measurement.measurement_record_index << "\n"; + } + if (!e.flipped_measurement.measured_observable.empty()) { out << indent << " flipped_measurement.measured_observable: "; print_pauli_product(out, e.flipped_measurement.measured_observable); out << "\n";