diff --git a/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
index c43760fa..4d1b2cd5 100644
--- a/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
+++ b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
@@ -31,7 +31,10 @@ def decode_shots_bit_packed(
             syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
             self.solver.solve(syndrome)
             prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
-            predictions[shot] = np.packbits(prediction, bitorder='little')
+            predictions[shot] = np.packbits(
+                np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1], dtype=np.uint8),
+                bitorder="little",
+            )
             self.solver.clear()
         return predictions
 
diff --git a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
index 915fd0b0..22536ff6 100644
--- a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
+++ b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
@@ -38,7 +38,9 @@ def decode_shots_bit_packed(
         bit_packed_detection_event_data: "np.ndarray",
     ) -> "np.ndarray":
         num_shots = bit_packed_detection_event_data.shape[0]
-        predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
+        predictions = np.zeros(
+            shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
+        )
         import mwpf
 
         for shot in range(num_shots):
@@ -58,29 +60,42 @@ def decode_shots_bit_packed(
                 np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
             )
             self.solver.clear()
-            predictions[shot] = np.packbits(prediction, bitorder="little")
+            predictions[shot] = np.packbits(
+                np.array(
+                    list(np.binary_repr(prediction, width=self.num_obs))[::-1],
+                    dtype=np.uint8,
+                ),
+                bitorder="little",
+            )
         return predictions
 
 
 class MwpfDecoder(Decoder):
     """Use MWPF to predict observables from detection events."""
 
-    def compile_decoder_for_dem(
+    def __init__(
         self,
-        *,
-        dem: "stim.DetectorErrorModel",
         decoder_cls: Any = None,  # decoder class used to construct the MWPF decoder.
         # in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
         # but just provide different plugins for optimizing the primal and/or dual solutions.
         # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
         # grows the clusters until the first valid solution appears; some more optimized solvers uses
         # one or more plugins to further optimize the solution, which requires longer decoding time.
-        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
+        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster,
+    ):
+        self.decoder_cls = decoder_cls
+        self.cluster_node_limit = cluster_node_limit
+        super().__init__()
+
+    def compile_decoder_for_dem(
+        self,
+        *,
+        dem: "stim.DetectorErrorModel",
     ) -> CompiledDecoder:
         solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
             dem,
-            decoder_cls=decoder_cls,
-            cluster_node_limit=cluster_node_limit,
+            decoder_cls=self.decoder_cls,
+            cluster_node_limit=self.cluster_node_limit,
         )
         return MwpfCompiledDecoder(
             solver,
@@ -99,13 +114,14 @@ def decode_via_files(
         dets_b8_in_path: pathlib.Path,
         obs_predictions_b8_out_path: pathlib.Path,
         tmp_dir: pathlib.Path,
-        decoder_cls: Any = None,
     ) -> None:
         import mwpf
 
         error_model = stim.DetectorErrorModel.from_file(dem_path)
         solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-            error_model, decoder_cls=decoder_cls
+            error_model,
+            decoder_cls=self.decoder_cls,
+            cluster_node_limit=self.cluster_node_limit,
         )
         num_det_bytes = math.ceil(num_dets / 8)
         with open(dets_b8_in_path, "rb") as dets_in_f:
@@ -136,44 +152,8 @@ def decode_via_files(
 
 
 class HyperUFDecoder(MwpfDecoder):
-    def compile_decoder_for_dem(
-        self, *, dem: "stim.DetectorErrorModel"
-    ) -> CompiledDecoder:
-        try:
-            import mwpf
-        except ImportError as ex:
-            raise mwpf_import_error() from ex
-
-        return super().compile_decoder_for_dem(
-            dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
-        )
-
-    def decode_via_files(
-        self,
-        *,
-        num_shots: int,
-        num_dets: int,
-        num_obs: int,
-        dem_path: pathlib.Path,
-        dets_b8_in_path: pathlib.Path,
-        obs_predictions_b8_out_path: pathlib.Path,
-        tmp_dir: pathlib.Path,
-    ) -> None:
-        try:
-            import mwpf
-        except ImportError as ex:
-            raise mwpf_import_error() from ex
-
-        return super().decode_via_files(
-            num_shots=num_shots,
-            num_dets=num_dets,
-            num_obs=num_obs,
-            dem_path=dem_path,
-            dets_b8_in_path=dets_b8_in_path,
-            obs_predictions_b8_out_path=obs_predictions_b8_out_path,
-            tmp_dir=tmp_dir,
-            decoder_cls=mwpf.SolverSerialUnionFind,
-        )
+    def __init__(self):
+        super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)
 
 
 def iter_flatten_model(
@@ -193,16 +173,16 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
                     _helper(instruction.body_copy(), instruction.repeat_count)
                 elif isinstance(instruction, stim.DemInstruction):
                     if instruction.type == "error":
-                        dets: List[int] = []
-                        frames: List[int] = []
+                        dets: set[int] = set()
+                        frames: set[int] = set()
                         t: stim.DemTarget
                         p = instruction.args_copy()[0]
                         for t in instruction.targets_copy():
                             if t.is_relative_detector_id():
-                                dets.append(t.val + det_offset)
+                                dets ^= {t.val + det_offset}
                             elif t.is_logical_observable_id():
-                                frames.append(t.val)
-                        handle_error(p, dets, frames)
+                                frames ^= {t.val}
+                        handle_error(p, list(dets), list(frames))
                     elif instruction.type == "shift_detectors":
                         det_offset += instruction.targets_copy()[0]
                         a = np.array(instruction.args_copy())
@@ -310,6 +290,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
     if decoder_cls is None:
         # default to the solver with highest accuracy
         decoder_cls = mwpf.SolverSerialJointSingleHair
+    elif isinstance(decoder_cls, str):
+        decoder_cls = getattr(mwpf, decoder_cls)
     return (
         (
             decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
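Note on the bit-packing change (it appears in both decoders above): the old code passed the integer `prediction` straight to `np.packbits`, which treats it as a single truthy element, so every nonzero prediction packs to the same single set bit and per-observable information is lost whenever there is more than one logical observable. A minimal sketch of the difference, using made-up values for `num_obs` and `prediction` (only the two `np.packbits` expressions are taken from the diff):

    import numpy as np

    num_obs = 10               # hypothetical number of logical observables
    prediction = 0b1000000101  # hypothetical XOR-reduced fault mask: observables 0, 2, 9 flipped

    # Old expression: the scalar is treated as one truthy element.
    old = np.packbits(prediction, bitorder="little")  # -> [1], regardless of which bits are set

    # New expression: expand the integer into num_obs bits (least significant bit first),
    # then pack those bits into the little-endian bit-packed prediction layout.
    bits = np.array(list(np.binary_repr(prediction, width=num_obs))[::-1], dtype=np.uint8)
    new = np.packbits(bits, bitorder="little")        # -> [5, 2], i.e. bits 0, 2 and 9

The constructor refactor is what lets `HyperUFDecoder` collapse to `MwpfDecoder(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)`: passing the solver class by name keeps `import mwpf` out of `__init__`, and the string is resolved later via `getattr(mwpf, decoder_cls)` inside `detector_error_model_to_mwpf_solver_and_fault_masks`.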