Commit
resolve conflict with main
samuelgarcia committed May 16, 2024
2 parents 1d03eec + cec72f4 commit 340dd12
Showing 19 changed files with 138 additions and 41 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -23,9 +23,10 @@ dependencies = [
     "numpy",
     "threadpoolctl>=3.0.0",
     "tqdm",
-    "zarr>=0.2.16",
+    "zarr>=2.16,<2.18",
    "neo>=0.13.0",
    "probeinterface>=0.2.21",
+    "packaging",
 ]

 [build-system]
1 change: 1 addition & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -499,6 +499,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                 channel_ids=self.sorting_analyzer.channel_ids,
                 unit_ids=unit_ids,
                 probe=self.sorting_analyzer.get_probe(),
+                is_scaled=self.sorting_analyzer.return_scaled,
             )
         else:
             raise ValueError("outputs must be numpy or Templates")
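The one-line addition above means Templates objects built by this extension now carry the analyzer's scaling state. A minimal sketch of reading the flag back, assuming an already-computed SortingAnalyzer named `analyzer` (the variable name is an assumption for illustration):

    # Sketch: fetch templates as a Templates object and inspect the new flag.
    ext = analyzer.get_extension("templates")
    templates = ext.get_templates(operator="average", outputs="Templates")
    units = "uV" if templates.is_scaled else "raw ADC values"
    print(f"Template values are in {units}")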
37 changes: 26 additions & 11 deletions src/spikeinterface/core/channelsaggregationrecording.py
@@ -14,18 +14,31 @@ class ChannelsAggregationRecording(BaseRecording):
     """

     def __init__(self, recording_list, renamed_channel_ids=None):
+
+        # Generate a default list of channel ids that are unique and consecutive numbers as strings.
         channel_map = {}
-
-        num_all_channels = sum([rec.get_num_channels() for rec in recording_list])
+        num_all_channels = sum(rec.get_num_channels() for rec in recording_list)
         if renamed_channel_ids is not None:
-            assert len(np.unique(renamed_channel_ids)) == num_all_channels, (
-                "'renamed_channel_ids' doesn't have the " "right size or has duplicates!"
-            )
+            assert (
+                len(np.unique(renamed_channel_ids)) == num_all_channels
+            ), "'renamed_channel_ids' doesn't have the right size or has duplicates!"
             channel_ids = list(renamed_channel_ids)
         else:
-            channel_ids = list(np.arange(num_all_channels))
+            # Collect channel IDs from all recordings
+            all_channels_have_same_type = np.unique([rec.channel_ids.dtype for rec in recording_list]).size == 1
+            all_channel_ids_are_unique = False
+            if all_channels_have_same_type:
+                combined_ids = np.concatenate([rec.channel_ids for rec in recording_list])
+                all_channel_ids_are_unique = np.unique(combined_ids).size == num_all_channels
+
+            if all_channels_have_same_type and all_channel_ids_are_unique:
+                channel_ids = combined_ids
+            else:
+                # If IDs are not unique or not of the same type, use default as stringify IDs
+                default_channel_ids = [str(i) for i in range(num_all_channels)]
+                channel_ids = default_channel_ids

         # channel map maps channel indices that are used to get traces
         ch_id = 0
         for r_i, recording in enumerate(recording_list):
             single_channel_ids = recording.get_channel_ids()
@@ -49,7 +62,9 @@ def __init__(self, recording_list, renamed_channel_ids=None):
                 break

         if not (ok1 and ok2 and ok3 and ok4):
-            raise ValueError("Sortings don't have the same sampling_frequency/num_segments/dtype/num samples")
+            raise ValueError(
+                "Recordings do not have consistent sampling frequency, number of segments, data type, or number of samples."
+            )

         BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)

@@ -91,7 +106,7 @@ def __init__(self, recording_list, renamed_channel_ids=None):
             self.add_recording_segment(sub_segment)

         self._recordings = recording_list
-        self._kwargs = {"recording_list": [rec for rec in recording_list], "renamed_channel_ids": renamed_channel_ids}
+        self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids}

     @property
     def recordings(self):
@@ -173,11 +188,11 @@ def aggregate_channels(recording_list, renamed_channel_ids=None):
     recording_list: list
         List of BaseRecording objects to aggregate
     renamed_channel_ids: array-like
-        If given, channel ids are renamed as provided. If None, unit ids are sequential integers.
+        If given, channel ids are renamed as provided.

     Returns
     -------
-    aggregate_recording: UnitsAggregationSorting
-        The aggregated sorting object
+    aggregate_recording: ChannelsAggregationRecording
+        The aggregated recording object
     """
     return ChannelsAggregationRecording(recording_list, renamed_channel_ids)
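Net effect of this file's changes: aggregation keeps the original channel ids whenever they are unique and share a dtype, and otherwise falls back to stringified indices ("0", "1", ...). A quick sketch of the new behavior, mirroring the tests added further down:

    from spikeinterface.core import aggregate_channels, generate_recording

    # Two recordings with distinct string ids: the aggregate preserves them.
    rec1 = generate_recording(num_channels=3, durations=[10], set_probe=False)
    rec1 = rec1.rename_channels(new_channel_ids=["a", "b", "c"])
    rec2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
    rec2 = rec2.rename_channels(new_channel_ids=["d", "e"])

    agg = aggregate_channels([rec1, rec2])
    assert list(agg.get_channel_ids()) == ["a", "b", "c", "d", "e"]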
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
@@ -1448,7 +1448,7 @@ def generate_templates(
     mode="ellipsoid",
 ):
     """
-    Generate some templates from the given channel positions and neuron position.s
+    Generate some templates from the given channel positions and neuron positions.

     The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform()
     and duplicates this same waveform on all channel given a simple decay law per unit.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
@@ -195,7 +195,7 @@ def __init__(
     ):
         # very fast init because checks are done in load and create
         self.sorting = sorting
-        # self.recorsding will be a property
+        # self.recording will be a property
         self._recording = recording
         self.rec_attributes = rec_attributes
         self.format = format
8 changes: 7 additions & 1 deletion src/spikeinterface/core/template.py
@@ -30,9 +30,11 @@ class Templates:
         Array of unit IDs. If `None`, defaults to an array of increasing integers.
     probe: Probe, default: None
         A `probeinterface.Probe` object
+    is_scaled : bool, optional default: True
+        If True, it means that the templates are in uV, otherwise they are in raw ADC values.
     check_for_consistent_sparsity : bool, optional default: None
         When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
-        structure fo the sparsity_masl.
+        structure of the sparsity_mask. If False, this check is skipped.

     The following attributes are available after construction:
@@ -58,6 +60,7 @@ class Templates:
     templates_array: np.ndarray
     sampling_frequency: float
     nbefore: int
+    is_scaled: bool = True

     sparsity_mask: np.ndarray = None
     channel_ids: np.ndarray = None
@@ -193,6 +196,7 @@ def to_dict(self):
             "unit_ids": self.unit_ids,
             "sampling_frequency": self.sampling_frequency,
             "nbefore": self.nbefore,
+            "is_scaled": self.is_scaled,
             "probe": self.probe.to_dict() if self.probe is not None else None,
         }

@@ -205,6 +209,7 @@ def from_dict(cls, data):
             unit_ids=np.asarray(data["unit_ids"]),
             sampling_frequency=data["sampling_frequency"],
             nbefore=data["nbefore"],
+            is_scaled=data["is_scaled"],
             probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
         )

@@ -238,6 +243,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:

         zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
         zarr_group.attrs["nbefore"] = self.nbefore
+        zarr_group.attrs["is_scaled"] = self.is_scaled

         if self.sparsity_mask is not None:
             zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
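A short sketch of the new field in use: constructing a Templates object from raw (unscaled) waveforms and round-tripping it through the to_dict/from_dict serialization shown above. Shapes and values here are invented for illustration:

    import numpy as np
    from spikeinterface.core import Templates

    num_units, num_samples, num_channels = 2, 90, 4
    templates = Templates(
        templates_array=np.zeros((num_units, num_samples, num_channels), dtype="float32"),
        sampling_frequency=30_000.0,
        nbefore=30,
        is_scaled=False,  # raw ADC values, not uV
    )

    # The flag survives dict serialization.
    restored = Templates.from_dict(templates.to_dict())
    assert restored.is_scaled is False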
20 changes: 14 additions & 6 deletions src/spikeinterface/core/template_tools.py
@@ -58,6 +58,7 @@ def get_template_amplitudes(
     peak_sign: "neg" | "pos" | "both" = "neg",
     mode: "extremum" | "at_index" = "extremum",
     return_scaled: bool = True,
+    abs_value: bool = True,
 ):
     """
     Get amplitude per channel for each unit.
@@ -73,6 +74,8 @@ def get_template_amplitudes(
         "at_index": take value at spike index
     return_scaled: bool, default True
         The amplitude is scaled or not.
+    abs_value: bool = True
+        Whether the extremum amplitude should be returned as an absolute value or not

     Returns
     -------
@@ -96,17 +99,18 @@ def get_template_amplitudes(
             if peak_sign == "both":
                 values = np.max(np.abs(template), axis=0)
             elif peak_sign == "neg":
-                values = -np.min(template, axis=0)
+                values = np.min(template, axis=0)
             elif peak_sign == "pos":
                 values = np.max(template, axis=0)
         elif mode == "at_index":
             if peak_sign == "both":
                 values = np.abs(template[before, :])
-            elif peak_sign == "neg":
-                values = -template[before, :]
-            elif peak_sign == "pos":
+            elif peak_sign in ["neg", "pos"]:
                 values = template[before, :]

+        if abs_value:
+            values = np.abs(values)
+
         peak_values[unit_id] = values

     return peak_values
@@ -160,7 +164,7 @@ def get_template_extremum_channel(
     extremum_channels_id = {}
     extremum_channels_index = {}
     for unit_id in unit_ids:
-        max_ind = np.argmax(peak_values[unit_id])
+        max_ind = np.argmax(np.abs(peak_values[unit_id]))
         extremum_channels_id[unit_id] = channel_ids[max_ind]
         extremum_channels_index[unit_id] = max_ind

@@ -227,6 +231,7 @@ def get_template_extremum_amplitude(
     templates_or_sorting_analyzer,
     peak_sign: "neg" | "pos" | "both" = "neg",
     mode: "extremum" | "at_index" = "at_index",
+    abs_value: bool = True,
 ):
     """
     Computes amplitudes on the best channel.
@@ -241,6 +246,9 @@ def get_template_extremum_amplitude(
         Where the amplitude is computed
         "extremum": max or min
         "at_index": take value at spike index
+    abs_value: bool = True
+        Whether the extremum amplitude should be returned as an absolute value or not
+

     Returns
     -------
@@ -260,7 +268,7 @@ def get_template_extremum_amplitude(
         return_scaled = True

     extremum_amplitudes = get_template_amplitudes(
-        templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
+        templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value
     )

     unit_amplitudes = {}
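In short: amplitudes now come back as absolute values by default (abs_value=True), and the extremum channel is chosen on the absolute peak value, so units with positive-going spikes are handled symmetrically. A usage sketch, assuming an existing SortingAnalyzer named `analyzer` with templates computed:

    from spikeinterface.core import get_template_extremum_amplitude

    # Default: negative peaks are reported as positive magnitudes.
    amps_abs = get_template_extremum_amplitude(analyzer, peak_sign="neg", abs_value=True)
    # Opt out to keep the signed value at the peak sample.
    amps_signed = get_template_extremum_amplitude(analyzer, peak_sign="neg", abs_value=False)

    for unit_id, amp in amps_abs.items():
        assert amp == abs(amps_signed[unit_id])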
49 changes: 42 additions & 7 deletions src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -23,7 +23,6 @@ def test_channelsaggregationrecording():

     # test num channels
     recording_agg = aggregate_channels([recording1, recording2, recording3])
-    print(recording_agg)
     assert len(recording_agg.get_channel_ids()) == 3 * num_channels

     assert np.allclose(recording_agg.get_times(0), recording1.get_times(0))
@@ -37,21 +36,21 @@ def test_channelsaggregationrecording():
         traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg)
         traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg)

-        assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[channel_ids[1]], segment_index=seg))
+        assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
         assert np.allclose(
-            traces2_0, recording_agg.get_traces(channel_ids=[num_channels + channel_ids[0]], segment_index=seg)
+            traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
         )
         assert np.allclose(
-            traces3_2, recording_agg.get_traces(channel_ids=[2 * num_channels + channel_ids[2]], segment_index=seg)
+            traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
         )
         # all traces
         traces1 = recording1.get_traces(segment_index=seg)
         traces2 = recording2.get_traces(segment_index=seg)
         traces3 = recording3.get_traces(segment_index=seg)

-        assert np.allclose(traces1, recording_agg.get_traces(channel_ids=[0, 1, 2], segment_index=seg))
-        assert np.allclose(traces2, recording_agg.get_traces(channel_ids=[3, 4, 5], segment_index=seg))
-        assert np.allclose(traces3, recording_agg.get_traces(channel_ids=[6, 7, 8], segment_index=seg))
+        assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg))
+        assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg))
+        assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg))

     # test rename channels
     renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)]
@@ -81,5 +80,41 @@ def test_channelsaggregationrecording():
     print(recording_agg_prop.get_property("brain_area"))


+def test_channel_aggregation_preserve_ids():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=["d", "e"])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"]
+
+
+def test_channel_aggregation_does_not_preserve_ids_if_not_unique():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=["a", "b"])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["0", "1", "2", "3", "4"]
+
+
+def test_channel_aggregation_does_not_preserve_ids_not_the_same_type():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=[1, 2])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["0", "1", "2", "3", "4"]
+
+
 if __name__ == "__main__":
     test_channelsaggregationrecording()
1 change: 1 addition & 0 deletions src/spikeinterface/core/tests/test_template_tools.py
@@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
         sparsity_mask=None,
         channel_ids=sorting_analyzer.channel_ids,
         unit_ids=sorting_analyzer.unit_ids,
+        is_scaled=sorting_analyzer.return_scaled,
     )
     return templates

9 changes: 6 additions & 3 deletions src/spikeinterface/extractors/neoextractors/spikegadgets.py
@@ -1,9 +1,10 @@
 from __future__ import annotations

 from pathlib import Path

-import probeinterface
-import packaging
+import packaging.version
+import probeinterface
+
 from spikeinterface.core.core_tools import define_function_from_class

 from .neobaseextractor import NeoBaseRecordingExtractor
@@ -38,7 +39,9 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None
         )
         self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id))

-        probegroup = probeinterface.read_spikegadgets(file_path, raise_error=False)
+        probegroup = None  # TODO remove once probeinterface is updated to 0.2.22 in the pyproject.toml
+        if packaging.version.parse(probeinterface.__version__) > packaging.version.parse("0.2.21"):
+            probegroup = probeinterface.read_spikegadgets(file_path, raise_error=False)

         if probegroup is not None:
             self.set_probes(probegroup, in_place=True)
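The guarded call above is the usual version-gating idiom: parse __version__ with packaging rather than comparing strings. A generic sketch of the same pattern (the helper name is invented for illustration):

    import packaging.version

    def version_at_least(module, minimum: str) -> bool:
        # PEP 440-aware comparison; plain string comparison would misorder
        # versions such as "0.2.9" vs "0.2.21".
        return packaging.version.parse(module.__version__) >= packaging.version.parse(minimum)

    import probeinterface
    if version_at_least(probeinterface, "0.2.22"):
        ...  # safe to rely on APIs introduced in probeinterface 0.2.22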
1 change: 1 addition & 0 deletions src/spikeinterface/generation/drifting_generator.py
@@ -404,6 +404,7 @@ def generate_drifting_recording(
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )

     drifting_templates = DriftingTemplates.from_static(templates)
1 change: 1 addition & 0 deletions src/spikeinterface/generation/tests/test_drift_tools.py
@@ -73,6 +73,7 @@ def make_some_templates():
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )

     return templates
16 changes: 9 additions & 7 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from operator import is_

 from .si_based import ComponentsBasedSorter

@@ -250,13 +251,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         )

         templates = Templates(
-            templates_array,
-            sampling_frequency,
-            nbefore,
-            None,
-            recording_w.channel_ids,
-            unit_ids,
-            recording_w.get_probe(),
+            templates_array=templates_array,
+            sampling_frequency=sampling_frequency,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording_w.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording_w.get_probe(),
+            is_scaled=False,
         )

         sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
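Switching this call to keyword arguments appears to be what keeps it correct: with is_scaled now sitting between nbefore and sparsity_mask in the Templates dataclass, the old positional call would have bound None to is_scaled rather than to sparsity_mask.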
2 changes: 2 additions & 0 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -220,7 +220,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             templates_array=templates_array,
             sampling_frequency=sampling_frequency,
             nbefore=nbefore,
+            sparsity_mask=None,
             probe=recording_for_peeler.get_probe(),
+            is_scaled=False,
         )

         # TODO : try other methods for sparsity