Skip to content

Commit

Permalink
Merge pull request #2513 from DradeAW/astype_rounding
Browse files Browse the repository at this point in the history
Added `round` option to `recording.astype`
  • Loading branch information
alejoe91 authored Mar 13, 2024
2 parents 6d382a2 + 5f6b898 commit 62f5199
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,10 +647,10 @@ def binary_compatible_with(
# good job you pass all crucible
return True

def astype(self, dtype):
def astype(self, dtype, round: bool | None = None):
from ..preprocessing.astype import astype

return astype(self, dtype=dtype)
return astype(self, dtype=dtype, round=round)


class BaseRecordingSegment(BaseSegment):
Expand Down
15 changes: 15 additions & 0 deletions src/spikeinterface/preprocessing/astype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import numpy as np

from ..core.core_tools import define_function_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .filter import fix_dtype
Expand All @@ -11,6 +13,9 @@ class AstypeRecording(BasePreprocessor):
Converts a recording to another dtype on the fly.
For recording with an unsigned dtype, please use the `unsigned_to_signed` preprocessing function.
If `round` is True, will round the values to the nearest integer.
If `round` is None, will round in the case of float to integer conversion.
"""

name = "astype"
Expand All @@ -19,20 +24,26 @@ def __init__(
self,
recording,
dtype=None,
round: bool | None = None,
):
dtype_ = fix_dtype(recording, dtype)
BasePreprocessor.__init__(self, recording, dtype=dtype_)

if round is None:
round = np.issubdtype(dtype, np.integer)

for parent_segment in recording._recording_segments:
rec_segment = AstypeRecordingSegment(
parent_segment,
dtype,
round,
)
self.add_recording_segment(rec_segment)

self._kwargs = dict(
recording=recording,
dtype=dtype_.str,
round=round,
)


Expand All @@ -41,14 +52,18 @@ def __init__(
self,
parent_recording_segment,
dtype,
round: bool,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.dtype = dtype
self.round = round

def get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
if self.round:
np.round(traces, out=traces)
return traces.astype(self.dtype, copy=False)


Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/preprocessing/tests/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def test_astype():
traces = (rng.randn(10000, 4) * 100).astype("float32")
rec_float32 = NumpyRecording(traces, sampling_frequency=30000)
traces_int16 = traces.astype("int16")
np.testing.assert_array_equal(traces_int16, astype(rec_float32, "int16").get_traces())
np.testing.assert_array_equal(traces_int16, astype(rec_float32, "int16", round=False).get_traces())
traces_int16_rounded = traces.round().astype("int16")
np.testing.assert_array_equal(traces_int16_rounded, astype(rec_float32, "int16").get_traces())
traces_float64 = traces.astype("float64")
np.testing.assert_array_equal(traces_float64, astype(rec_float32, "float64").get_traces())

Expand Down

0 comments on commit 62f5199

Please sign in to comment.