From 9811a6a8ab0d398e3eb44e3b6e7040ecccc3de46 Mon Sep 17 00:00:00 2001
From: Yue Wu
Date: Fri, 31 Jan 2025 14:45:09 -0500
Subject: [PATCH] Fix bit pack of mwpf and fusion blossom decoders under multiple logical observable (#873)

This PR fixes two bugs in the MWPF decoder.

## 1. Supporting decomposed detector error models

While MWPF expects a decoding hypergraph, the input detector error model from sinter is decomposed by default. A decomposed DEM may mention the same detector or logical observable multiple times within a single error instruction, which the previous implementation did not account for.

The previous implementation assumed that each detector and logical observable appears at most once, so I used

```python
frames: List[int] = []
...
frames.append(t.val)
```

However, this no longer works if the same frame appears in multiple decomposed parts. In that case, the DEM actually means "the hyperedge contributes to the logical observable iff count(frame) % 2 == 1". This is fixed by

```python
frames: set[int] = set()
...
frames ^= { t.val }
```

## 2. Supporting multiple logical observables

Although a previous [PR #864](https://github.com/quantumlib/Stim/pull/864) fixed the panic when multiple logical observables are encountered, the returned value was still wrong and caused a significantly higher logical error rate.

The previous implementation converted an `int`-typed bitmask to a bit-packed value using `np.packbits(prediction, bitorder="little")`. However, this does not work for more than one logical observable. For example, for an observable defined with `OBSERVABLE_INCLUDE(2) ...`, the bit-packed value should be `[4]` because $1<<2 = 4$. However, `np.packbits(4, bitorder="little") = [1]`, which is incorrect.

The correct procedure is to first generate the binary representation with `self.num_obs` bits using `np.binary_repr(prediction, width=self.num_obs)` (in this case `'100'`), then reverse the order of the bits to `['0', '0', '1']`, and finally run `np.packbits`, which gives the correct value `[4]`.
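As a quick sanity check, here is a minimal NumPy-only sketch of the difference (the `num_obs = 3` and `prediction` values are illustrative and not part of the patch):

```python
import numpy as np

num_obs = 3          # illustrative: observables 0, 1, 2
prediction = 1 << 2  # only OBSERVABLE_INCLUDE(2) fired, i.e. the integer 4

# Old behavior: packbits treats the nonzero scalar 4 as a single set bit.
old = np.packbits(prediction, bitorder="little")
print(old)  # [1]  -- wrong bit-packed mask

# Fixed behavior: expand to num_obs bits (least significant bit first), then pack.
bits = np.array(
    list(np.binary_repr(prediction, width=num_obs))[::-1], dtype=np.uint8
)
new = np.packbits(bits, bitorder="little")
print(new)  # [4]  -- matches the expected bit-packed mask
```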
The full code is below:

```python
predictions[shot] = np.packbits(
    np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
    bitorder="little",
)
```
---
 .../_decoding/_decoding_fusion_blossom.py  |  5 +-
 .../src/sinter/_decoding/_decoding_mwpf.py | 88 ++++++++-----------
 2 files changed, 39 insertions(+), 54 deletions(-)

diff --git a/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
index c43760fa..4d1b2cd5 100644
--- a/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
+++ b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
@@ -31,7 +31,10 @@ def decode_shots_bit_packed(
             syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
             self.solver.solve(syndrome)
             prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
-            predictions[shot] = np.packbits(prediction, bitorder='little')
+            predictions[shot] = np.packbits(
+                np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
+                bitorder="little",
+            )
             self.solver.clear()
         return predictions

diff --git a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
index 915fd0b0..22536ff6 100644
--- a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
+++ b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
@@ -38,7 +38,9 @@ def decode_shots_bit_packed(
         bit_packed_detection_event_data: "np.ndarray",
     ) -> "np.ndarray":
         num_shots = bit_packed_detection_event_data.shape[0]
-        predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
+        predictions = np.zeros(
+            shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
+        )
         import mwpf

         for shot in range(num_shots):
@@ -58,29 +60,42 @@ def decode_shots_bit_packed(
                 np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
             )
             self.solver.clear()
-            predictions[shot] = np.packbits(prediction, bitorder="little")
+            predictions[shot] = np.packbits(
+                np.array(
+                    list(np.binary_repr(prediction, width=self.num_obs))[::-1],
+                    dtype=np.uint8,
+                ),
+                bitorder="little",
+            )
         return predictions


 class MwpfDecoder(Decoder):
     """Use MWPF to predict observables from detection events."""

-    def compile_decoder_for_dem(
+    def __init__(
         self,
-        *,
-        dem: "stim.DetectorErrorModel",
         decoder_cls: Any = None,  # decoder class used to construct the MWPF decoder.
         # in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
         # but just provide different plugins for optimizing the primal and/or dual solutions.
         # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
         # grows the clusters until the first valid solution appears; some more optimized solvers uses
         # one or more plugins to further optimize the solution, which requires longer decoding time.
-        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
+        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster,
+    ):
+        self.decoder_cls = decoder_cls
+        self.cluster_node_limit = cluster_node_limit
+        super().__init__()
+
+    def compile_decoder_for_dem(
+        self,
+        *,
+        dem: "stim.DetectorErrorModel",
     ) -> CompiledDecoder:
         solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
             dem,
-            decoder_cls=decoder_cls,
-            cluster_node_limit=cluster_node_limit,
+            decoder_cls=self.decoder_cls,
+            cluster_node_limit=self.cluster_node_limit,
         )
         return MwpfCompiledDecoder(
             solver,
@@ -99,13 +114,14 @@ def decode_via_files(
         dets_b8_in_path: pathlib.Path,
         obs_predictions_b8_out_path: pathlib.Path,
         tmp_dir: pathlib.Path,
-        decoder_cls: Any = None,
     ) -> None:
         import mwpf

         error_model = stim.DetectorErrorModel.from_file(dem_path)
         solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-            error_model, decoder_cls=decoder_cls
+            error_model,
+            decoder_cls=self.decoder_cls,
+            cluster_node_limit=self.cluster_node_limit,
         )
         num_det_bytes = math.ceil(num_dets / 8)
         with open(dets_b8_in_path, "rb") as dets_in_f:
@@ -136,44 +152,8 @@ def decode_via_files(


 class HyperUFDecoder(MwpfDecoder):
-    def compile_decoder_for_dem(
-        self, *, dem: "stim.DetectorErrorModel"
-    ) -> CompiledDecoder:
-        try:
-            import mwpf
-        except ImportError as ex:
-            raise mwpf_import_error() from ex
-
-        return super().compile_decoder_for_dem(
-            dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
-        )
-
-    def decode_via_files(
-        self,
-        *,
-        num_shots: int,
-        num_dets: int,
-        num_obs: int,
-        dem_path: pathlib.Path,
-        dets_b8_in_path: pathlib.Path,
-        obs_predictions_b8_out_path: pathlib.Path,
-        tmp_dir: pathlib.Path,
-    ) -> None:
-        try:
-            import mwpf
-        except ImportError as ex:
-            raise mwpf_import_error() from ex
-
-        return super().decode_via_files(
-            num_shots=num_shots,
-            num_dets=num_dets,
-            num_obs=num_obs,
-            dem_path=dem_path,
-            dets_b8_in_path=dets_b8_in_path,
-            obs_predictions_b8_out_path=obs_predictions_b8_out_path,
-            tmp_dir=tmp_dir,
-            decoder_cls=mwpf.SolverSerialUnionFind,
-        )
+    def __init__(self):
+        super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)


 def iter_flatten_model(
@@ -193,16 +173,16 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
                     _helper(instruction.body_copy(), instruction.repeat_count)
                 elif isinstance(instruction, stim.DemInstruction):
                     if instruction.type == "error":
-                        dets: List[int] = []
-                        frames: List[int] = []
+                        dets: set[int] = set()
+                        frames: set[int] = set()
                         t: stim.DemTarget
                         p = instruction.args_copy()[0]
                         for t in instruction.targets_copy():
                             if t.is_relative_detector_id():
-                                dets.append(t.val + det_offset)
+                                dets ^= {t.val + det_offset}
                             elif t.is_logical_observable_id():
-                                frames.append(t.val)
-                        handle_error(p, dets, frames)
+                                frames ^= {t.val}
+                        handle_error(p, list(dets), list(frames))
                     elif instruction.type == "shift_detectors":
                         det_offset += instruction.targets_copy()[0]
                         a = np.array(instruction.args_copy())
@@ -310,6 +290,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
     if decoder_cls is None:
         # default to the solver with highest accuracy
         decoder_cls = mwpf.SolverSerialJointSingleHair
+    elif isinstance(decoder_cls, str):
+        decoder_cls = getattr(mwpf, decoder_cls)
     return (
         (
             decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})