Skip to content

Commit

Permalink
Fix circuit.explain_dem_errors not supporting all gates (#704)
Browse files Browse the repository at this point in the history
- Fix a nasty miscomputed allocation size om `stim::MonotonicBuffer`
- Add `stim::inplace_xor_sort` C++ helper method
- Add support for `MXX`, `MYY`, `MZZ`, `HERALDED_ERASE`,
`HERALDED_PAULI_CHANNEL_1`, `MPAD` to `stim::ErrorMatcher`
- Add a unit test verifying `stim::ErrorMatcher` supports all gates

Fixes #697
  • Loading branch information
Strilanc authored Feb 29, 2024
1 parent f11b4f9 commit 61149a9
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/stim/mem/monotonic_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct MonotonicBuffer {
return;
}

size_t alloc_count = std::max(min_required, cur.size() << 1);
size_t alloc_count = std::max(min_required + tail.size(), cur.size() << 1);
if (cur.ptr_start != nullptr) {
old_areas.push_back(cur);
}
Expand Down
14 changes: 13 additions & 1 deletion src/stim/mem/monotonic_buffer.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TEST(pointer_range, equality) {
ASSERT_NE(r1, r2);
}

TEST(monotonic_buffer, x) {
TEST(monotonic_buffer, append_tail) {
MonotonicBuffer<int> buf;
for (size_t k = 0; k < 100; k++) {
buf.append_tail(k);
Expand All @@ -51,3 +51,15 @@ TEST(monotonic_buffer, x) {
ASSERT_EQ(rng[k], k);
}
}

TEST(monotonic_buffer, ensure_available) {
MonotonicBuffer<int> buf;
buf.append_tail(std::vector<int>{1, 2, 3, 4});
buf.append_tail(std::vector<int>{5, 6});
buf.append_tail(std::vector<int>{7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

SpanRef<const int> rng = buf.commit_tail();
std::vector<int> expected{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
SpanRef<const int> v = expected;
ASSERT_EQ(rng, v);
}
17 changes: 17 additions & 0 deletions src/stim/mem/sparse_xor_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ inline T *xor_merge_sort(SpanRef<const T> sorted_in1, SpanRef<const T> sorted_in
return out;
}

template <typename T>
inline SpanRef<T> inplace_xor_sort(SpanRef<T> items) {
std::sort(items.begin(), items.end());
size_t new_size = 0;
for (size_t k = 0; k < items.size(); k++) {
if (new_size > 0 && items[k] == items[new_size - 1]) {
new_size--;
} else {
if (k != new_size) {
std::swap(items[new_size], items[k]);
}
new_size++;
}
}
return items.sub(0, new_size);
}

template <typename T>
bool is_subset_of_sorted(SpanRef<const T> subset, SpanRef<const T> superset) {
const T *p_sub = subset.ptr_start;
Expand Down
19 changes: 19 additions & 0 deletions src/stim/mem/sparse_xor_vec.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,22 @@ TEST(sparse_xor_vec, contains) {
ASSERT_FALSE((SparseXorVec<uint32_t>{{}}).contains(0));
ASSERT_FALSE((SparseXorVec<uint32_t>{{1}}).contains(0));
}

TEST(sparse_xor_vec, inplace_xor_sort) {
auto f = [](std::vector<int> v) -> std::vector<int> {
SpanRef<int> s = v;
auto r = inplace_xor_sort(s);
v.resize(r.size());
return v;
};
ASSERT_EQ(f({}), (std::vector<int>({})));
ASSERT_EQ(f({5}), (std::vector<int>({5})));
ASSERT_EQ(f({5, 5}), (std::vector<int>({})));
ASSERT_EQ(f({5, 5, 5}), (std::vector<int>({5})));
ASSERT_EQ(f({5, 5, 5, 5}), (std::vector<int>({})));
ASSERT_EQ(f({5, 4, 5, 5}), (std::vector<int>({4, 5})));
ASSERT_EQ(f({4, 5, 5, 5}), (std::vector<int>({4, 5})));
ASSERT_EQ(f({5, 5, 5, 4}), (std::vector<int>({4, 5})));
ASSERT_EQ(f({4, 5, 5, 4}), (std::vector<int>({})));
ASSERT_EQ(f({3, 5, 5, 4}), (std::vector<int>({3, 4})));
}
2 changes: 1 addition & 1 deletion src/stim/simulators/error_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ void ErrorAnalyzer::undo_DEPOLARIZE2(const CircuitInstruction &dat) {

void ErrorAnalyzer::undo_ELSE_CORRELATED_ERROR(const CircuitInstruction &dat) {
if (accumulate_errors) {
throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR" + dat.str());
throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR: " + dat.str());
}
}

Expand Down
189 changes: 144 additions & 45 deletions src/stim/simulators/error_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,7 @@ ErrorMatcher::ErrorMatcher(
}
}

void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
assert(error_analyzer.error_class_probabilities.empty());
error_analyzer.undo_gate(effect);
if (error_analyzer.error_class_probabilities.empty()) {
/// Maybe there were no detectors or observables nearby? Or the noise probability was zero?
return;
}

assert(error_analyzer.error_class_probabilities.size() == 1);
SpanRef<const DemTarget> dem_error_terms = error_analyzer.error_class_probabilities.begin()->first;
void ErrorMatcher::add_dem_error_terms(SpanRef<const DemTarget> dem_error_terms) {
auto entry = output_map.find(dem_error_terms);
if (!dem_error_terms.empty() && (allow_adding_new_dem_errors_to_output_map || entry != output_map.end())) {
// We have a desired match! Record it.
Expand All @@ -88,6 +79,19 @@ void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
out[0] = std::move(new_loc);
}
}
}

void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
assert(error_analyzer.error_class_probabilities.empty());
error_analyzer.undo_gate(effect);
if (error_analyzer.error_class_probabilities.empty()) {
/// Maybe there were no detectors or observables nearby? Or the noise probability was zero?
return;
}

assert(error_analyzer.error_class_probabilities.size() == 1);
SpanRef<const DemTarget> dem_error_terms = error_analyzer.error_class_probabilities.begin()->first;
add_dem_error_terms(dem_error_terms);

// Restore the pristine state.
error_analyzer.mono_buf.clear();
Expand Down Expand Up @@ -128,6 +132,58 @@ void ErrorMatcher::err_xyz(const CircuitInstruction &op, uint32_t target_flags)
}
}

void ErrorMatcher::err_heralded_pauli_channel_1(const CircuitInstruction &op) {
assert(op.args.size() == 4);
for (size_t k = op.targets.size(); k--;) {
auto q = op.targets[k].qubit_value();
cur_loc.instruction_targets.target_range_start = k;
cur_loc.instruction_targets.target_range_end = k + 1;

cur_loc.flipped_measurement.measurement_record_index = error_analyzer.tracker.num_measurements_in_past - 1;
SpanRef<const DemTarget> herald_symptoms = error_analyzer.tracker.rec_bits[error_analyzer.tracker.num_measurements_in_past - 1].range();
SpanRef<const DemTarget> x_symptoms = error_analyzer.tracker.zs[q].range();
SpanRef<const DemTarget> z_symptoms = error_analyzer.tracker.xs[q].range();
if (op.args[0] != 0) {
add_dem_error_terms(herald_symptoms);
}
if (op.args[1] != 0) {
error_analyzer.mono_buf.append_tail(herald_symptoms);
error_analyzer.mono_buf.append_tail(x_symptoms);
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT, cur_loc.flipped_pauli_product);
add_dem_error_terms(error_analyzer.mono_buf.tail);
cur_loc.flipped_pauli_product.clear();
error_analyzer.mono_buf.discard_tail();
}
if (op.args[2] != 0) {
error_analyzer.mono_buf.append_tail(herald_symptoms);
error_analyzer.mono_buf.append_tail(x_symptoms);
error_analyzer.mono_buf.append_tail(z_symptoms);
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product);
add_dem_error_terms(error_analyzer.mono_buf.tail);
cur_loc.flipped_pauli_product.clear();
error_analyzer.mono_buf.discard_tail();
}
if (op.args[3] != 0) {
error_analyzer.mono_buf.append_tail(herald_symptoms);
error_analyzer.mono_buf.append_tail(z_symptoms);
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
resolve_paulis_into(&op.targets[k], TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product);
add_dem_error_terms(error_analyzer.mono_buf.tail);
cur_loc.flipped_pauli_product.clear();
error_analyzer.mono_buf.discard_tail();
}
cur_loc.flipped_measurement.measurement_record_index = UINT64_MAX;

assert(error_analyzer.error_class_probabilities.empty());
error_analyzer.tracker.undo_gate(op);
error_analyzer.mono_buf.clear();
error_analyzer.error_class_probabilities.clear();
error_analyzer.flushed_reversed_model.clear();
}
}

void ErrorMatcher::err_pauli_channel_1(const CircuitInstruction &op) {
const auto &a = op.args;
const auto &t = op.targets;
Expand Down Expand Up @@ -187,12 +243,17 @@ void ErrorMatcher::err_m(const CircuitInstruction &op, uint32_t obs_mask) {
const auto &t = op.targets;
const auto &a = op.args;

bool q2 = GATE_DATA[op.gate_type].flags & GATE_TARGETS_PAIRS;
size_t end = t.size();
while (end > 0) {
size_t start = end - 1;
while (start > 0 && t[start - 1].is_combiner()) {
start -= std::min(start, size_t{2});
}
if (q2) {
start--;
}


SpanRef<const GateTarget> slice{t.begin() + start, t.begin() + end};

Expand Down Expand Up @@ -227,48 +288,86 @@ void ErrorMatcher::rev_process_instruction(const CircuitInstruction &op) {
entry->second.push_back(d);
}
}
return;
} else if (op.gate_type == GateType::SHIFT_COORDS) {
error_analyzer.undo_SHIFT_COORDS(op);
for (size_t k = 0; k < op.args.size(); k++) {
cur_coord_offset[k] -= op.args[k];
}
return;
} else if (!(flags & (GATE_IS_NOISY | GATE_PRODUCES_RESULTS))) {
error_analyzer.undo_gate(op);
} else if (op.gate_type == GateType::E || op.gate_type == GateType::ELSE_CORRELATED_ERROR) {
cur_loc.instruction_targets.target_range_start = 0;
cur_loc.instruction_targets.target_range_end = op.targets.size();
resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product);
err_atom(op);
cur_loc.flipped_pauli_product.clear();
} else if (op.gate_type == GateType::X_ERROR) {
err_xyz(op, TARGET_PAULI_X_BIT);
} else if (op.gate_type == GateType::Y_ERROR) {
err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
} else if (op.gate_type == GateType::Z_ERROR) {
err_xyz(op, TARGET_PAULI_Z_BIT);
} else if (op.gate_type == GateType::PAULI_CHANNEL_1) {
err_pauli_channel_1(op);
} else if (op.gate_type == GateType::DEPOLARIZE1) {
float p = op.args[0];
std::array<double, 3> spread{p, p, p};
err_pauli_channel_1({op.gate_type, spread, op.targets});
} else if (op.gate_type == GateType::PAULI_CHANNEL_2) {
err_pauli_channel_2(op);
} else if (op.gate_type == GateType::DEPOLARIZE2) {
float p = op.args[0];
std::array<double, 15> spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p};
err_pauli_channel_2({op.gate_type, spread, op.targets});
} else if (op.gate_type == GateType::MPP) {
err_m(op, 0);
} else if (op.gate_type == GateType::MX || op.gate_type == GateType::MRX) {
err_m(op, TARGET_PAULI_X_BIT);
} else if (op.gate_type == GateType::MY || op.gate_type == GateType::MRY) {
err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
} else if (op.gate_type == GateType::M || op.gate_type == GateType::MR) {
err_m(op, TARGET_PAULI_Z_BIT);
} else {
throw std::invalid_argument(
"Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name));
return;
}
switch (op.gate_type) {
case GateType::MPAD:
error_analyzer.undo_gate(op);
break;
case GateType::E:
case GateType::ELSE_CORRELATED_ERROR: {
cur_loc.instruction_targets.target_range_start = 0;
cur_loc.instruction_targets.target_range_end = op.targets.size();
resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product);
CircuitInstruction op2 = op;
op2.gate_type = GateType::E;
err_atom(op2);
cur_loc.flipped_pauli_product.clear();
break;
} case GateType::X_ERROR:
err_xyz(op, TARGET_PAULI_X_BIT);
break;
case GateType::Y_ERROR:
err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
break;
case GateType::Z_ERROR:
err_xyz(op, TARGET_PAULI_Z_BIT);
break;
case GateType::PAULI_CHANNEL_1:
err_pauli_channel_1(op);
break;
case GateType::HERALDED_PAULI_CHANNEL_1:
err_heralded_pauli_channel_1(op);
break;
case GateType::HERALDED_ERASE: {
float p = op.args[0] / 4;
std::array<double, 4> spread{p, p, p, p};
err_heralded_pauli_channel_1({op.gate_type, spread, op.targets});
break;
} case GateType::DEPOLARIZE1: {
float p = op.args[0];
std::array<double, 3> spread{p, p, p};
err_pauli_channel_1({op.gate_type, spread, op.targets});
break;
} case GateType::PAULI_CHANNEL_2:
err_pauli_channel_2(op);
break;
case GateType::DEPOLARIZE2: {
float p = op.args[0];
std::array<double, 15> spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p};
err_pauli_channel_2({op.gate_type, spread, op.targets});
break;
}
case GateType::MPP:
err_m(op, 0);
break;
case GateType::MX:
case GateType::MRX:
case GateType::MXX:
err_m(op, TARGET_PAULI_X_BIT);
break;
case GateType::MY:
case GateType::MRY:
case GateType::MYY:
err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
break;
case GateType::M:
case GateType::MR:
case GateType::MZZ:
err_m(op, TARGET_PAULI_Z_BIT);
break;
default:
throw std::invalid_argument(
"Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name));
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/stim/simulators/error_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ struct ErrorMatcher {
void err_atom(const CircuitInstruction &effect);
/// Processes operations with X, Y, Z errors on each target.
void err_pauli_channel_1(const CircuitInstruction &op);
/// Processes operations with M, X, Y, Z errors on each target.
void err_heralded_pauli_channel_1(const CircuitInstruction &op);
/// Processes operations with 15 two-qubit Pauli product errors on each target pair.
void err_pauli_channel_2(const CircuitInstruction &op);
/// Processes measurement operations.
Expand All @@ -88,6 +90,8 @@ struct ErrorMatcher {
void rev_process_instruction(const CircuitInstruction &op);
/// Processes entire circuits.
void rev_process_circuit(uint64_t reps, const Circuit &block);

void add_dem_error_terms(SpanRef<const DemTarget> dem_error_terms);
};

} // namespace stim
Expand Down
Loading

0 comments on commit 61149a9

Please sign in to comment.