diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 47e2ea2849..4a0c5f8eef 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -255,7 +255,7 @@ def read_zarr(
         The loaded extractor
     """
     # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
-    # for the futur SortingResult we will have this 2 fields!!!
+    # for the future SortingAnalyzer we will have these 2 fields!!!
     root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
     if "channel_ids" in root.keys():
         return read_zarr_recording(folder_path, storage_options=storage_options)
@@ -367,7 +367,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G
 
 
 # Recording
 def add_recording_to_zarr_group(
-    recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, **kwargs
+    recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, dtype=None, **kwargs
 ):
     zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)
@@ -382,7 +382,7 @@ def add_recording_to_zarr_group(
     zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None)
 
     dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())]
-    dtype = zarr_kwargs.get("dtype", None) or recording.get_dtype()
+    dtype = recording.get_dtype() if dtype is None else dtype
     channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None)
     global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor())
     compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {})
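
A minimal usage sketch of the new explicit dtype parameter (the generate_recording toy data and the output path are illustrative assumptions, not part of this change):

    import zarr
    from spikeinterface.core import generate_recording
    from spikeinterface.core.zarrextractors import add_recording_to_zarr_group

    # toy recording, purely for illustration
    recording = generate_recording(num_channels=4, durations=[1.0])
    zarr_group = zarr.open("example_recording.zarr", mode="w")

    # dtype=None (the default) keeps recording.get_dtype(); an explicit dtype
    # is now passed as a named argument instead of being pulled out of **kwargs
    add_recording_to_zarr_group(recording, zarr_group, dtype="int16")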