Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into batched-inference-and…
Browse files Browse the repository at this point in the history
…-padding
  • Loading branch information
stes committed Oct 27, 2024
2 parents e1b7cc7 + e652b9a commit 0eac868
Show file tree
Hide file tree
Showing 28 changed files with 354 additions and 153 deletions.
4 changes: 2 additions & 2 deletions cebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def __getattr__(key):

return CEBRA
elif key == "KNNDecoder":
from cebra.integrations.sklearn.decoder import KNNDecoder
from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811

return KNNDecoder
elif key == "L1LinearRegressor":
from cebra.integrations.sklearn.decoder import L1LinearRegressor
from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811

return L1LinearRegressor
elif not key.startswith("_"):
Expand Down
44 changes: 36 additions & 8 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
"""Pre-defined datasets."""

import types
from typing import List, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch

import cebra.data as cebra_data
import cebra.helper as cebra_helper
from cebra.data.datatypes import Offset


class TensorDataset(cebra_data.SingleSessionDataset):
Expand Down Expand Up @@ -64,26 +66,52 @@ def __init__(self,
neural: Union[torch.Tensor, npt.NDArray],
continuous: Union[torch.Tensor, npt.NDArray] = None,
discrete: Union[torch.Tensor, npt.NDArray] = None,
offset: int = 1,
offset: Offset = Offset(0, 1),
device: str = "cpu"):
super().__init__(device=device)
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
self.discrete = self._to_tensor(discrete, torch.LongTensor)
self.neural = self._to_tensor(neural, check_dtype="float").float()
self.continuous = self._to_tensor(continuous, check_dtype="float")
self.discrete = self._to_tensor(discrete, check_dtype="int")
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
)
self.offset = offset

def _to_tensor(self, array, check_dtype=None):
def _to_tensor(
self,
array: Union[torch.Tensor, npt.NDArray],
check_dtype: Optional[Literal["int",
"float"]] = None) -> torch.Tensor:
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
Args:
array: Array to check.
check_dtype: If not `None`, list of dtypes to which the values in `array`
must belong to. Defaults to None.
Returns:
The `array` as a :py:class:`torch.Tensor`.
"""
if array is None:
return None
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if not isinstance(array, check_dtype):
raise TypeError(f"{type(array)} instead of {check_dtype}.")
if check_dtype not in ["int", "float"]:
raise ValueError(
f"check_dtype must be 'int' or 'float', got {check_dtype}")
if (check_dtype == "int" and not cebra_helper._is_integer(array)
) or (check_dtype == "float" and
not cebra_helper._is_floating(array)):
raise TypeError(
f"Array has type {array.dtype} instead of {check_dtype}.")
if cebra_helper._is_floating(array):
array = array.float()
if cebra_helper._is_integer(array):
# NOTE(stes): Required for standardizing number format on
# windows machines.
array = array.long()
return array

@property
Expand Down
13 changes: 9 additions & 4 deletions cebra/data/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
For each dataset, the data and labels to align the data on is provided.
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
the labels of the reference dataset (``ref_label``) are selected and used to sample
from the dataset to align (``data``).
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
of samples ``subsample``.
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
on those subsampled datasets.
4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data``
to the ``ref_data``.
Note:
``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
Expand Down
3 changes: 2 additions & 1 deletion cebra/data/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,8 @@ def load(
- if no key is provided, the first data structure found upon iteration of the collection will be loaded;
- if a key is provided, it needs to correspond to an existing item of the collection;
- if a key is provided, the data value accessed needs to be a data structure;
- the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
- the function loads data for only one data structure, even if the file contains more. The function can be
called again with the corresponding key to get the other ones.
Args:
file: The path to the given file to load, in a supported format.
Expand Down
4 changes: 1 addition & 3 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def __post_init__(self):
# here might be sub-optimal. The final behavior should be determined after
# e.g. integrating the FAISS dataloader back in.
super().__post_init__()
index = self.index.to(self.device)

if self.conditional != "time_delta":
raise NotImplementedError(
Expand All @@ -368,8 +367,7 @@ def __post_init__(self):
self.time_distribution = cebra.distributions.TimeContrastive(
time_offset=self.time_offset,
num_samples=len(self.dataset.neural),
device=self.device,
)
device=self.device)
self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
self.dataset.continuous_index, self.time_offset, device=self.device)

Expand Down
2 changes: 0 additions & 2 deletions cebra/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str:
from cebra.datasets.monkey_reaching import *
from cebra.datasets.synthetic_data import *
except ModuleNotFoundError as e:
import warnings

warnings.warn(f"Could not initialize one or more datasets: {e}. "
f"For using the datasets, consider installing the "
f"[datasets] extension via pip.")
Expand Down
14 changes: 9 additions & 5 deletions cebra/datasets/allen/ca_movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
"""Allen pseudomouse Ca dataset.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
*https://github.com/zivlab/visual_drift
*http://observatory.brain-map.org/visualcoding
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
"Representational drift in the mouse visual cortex."
Current biology 31.19 (2021): 4327-4339.
* de Vries, Saskia EJ, et al.
"A large-scale standardized physiological survey reveals functional
organization of the mouse visual cortex."
Nature neuroscience 23.1 (2020): 138-151.
* https://github.com/zivlab/visual_drift
* http://observatory.brain-map.org/visualcoding
"""

import pathlib
Expand Down
19 changes: 9 additions & 10 deletions cebra/datasets/allen/ca_movie_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
"""Allen pseudomouse Ca decoding dataset with train/test split.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
*https://github.com/zivlab/visual_drift
*http://observatory.brain-map.org/visualcoding
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
"Representational drift in the mouse visual cortex."
Current biology 31.19 (2021): 4327-4339.
* de Vries, Saskia EJ, et al.
"A large-scale standardized physiological survey reveals functional
organization of the mouse visual cortex."
Nature neuroscience 23.1 (2020): 138-151.
* https://github.com/zivlab/visual_drift
* http://observatory.brain-map.org/visualcoding
"""

import pathlib
Expand Down Expand Up @@ -243,11 +247,6 @@ def _convert_to_nums(string):

return pseudo_mouse

pseudo_mouse = np.vstack(
[get_neural_data(num_movie, mice) for mice in list_mice])

return pseudo_mouse

def __len__(self):
return self.neural.size(0)

Expand Down
20 changes: 13 additions & 7 deletions cebra/datasets/allen/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@
"""Joint Allen pseudomouse Ca/Neuropixel datasets.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
*https://github.com/zivlab/visual_drift
*http://observatory.brain-map.org/visualcoding
*https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
*Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92.
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
"Representational drift in the mouse visual cortex."
Current Biology 31.19 (2021): 4327-4339.
* de Vries, Saskia EJ, et al.
"A large-scale standardized physiological survey reveals functional
organization of the mouse visual cortex."
Nature Neuroscience 23.1 (2020): 138-151.
* https://github.com/zivlab/visual_drift
* http://observatory.brain-map.org/visualcoding
* https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
* Siegle, Joshua H., et al.
"Survey of spiking in the mouse visual system reveals functional hierarchy."
Nature 592.7852 (2021): 86-92.
"""

import cebra.data
Expand Down
15 changes: 8 additions & 7 deletions cebra/datasets/allen/make_neuropixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,12 @@ def read_neuropixel(
"intervals/natural_movie_one_presentations/start_time"][...]
end_time = d[
"intervals/natural_movie_one_presentations/stop_time"][...]
timeseries = d[
"intervals/natural_movie_one_presentations/timeseries"][...]
timeseries_index = d[
"intervals/natural_movie_one_presentations/timeseries_index"][
...]
# NOTE(stes): never used. leaving here for future reference
#timeseries = d[
# "intervals/natural_movie_one_presentations/timeseries"][...]
#timeseries_index = d[
# "intervals/natural_movie_one_presentations/timeseries_index"][
# ...]
session_no = d["identifier"][...].item()
spike_time_index = d["units/spike_times_index"][...]
spike_times = d["units/spike_times"][...]
Expand Down Expand Up @@ -266,14 +267,14 @@ def read_neuropixel(
"neural": sessions_dic,
"frames": session_frames
},
Path(args.save_path) /
pathlib.Path(args.save_path) /
f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl",
)
jl.dump(
{
"neural": pseudo_mice,
"frames": pseudo_mice_frames
},
Path(args.save_path) /
pathlib.Path(args.save_path) /
f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl",
)
21 changes: 14 additions & 7 deletions cebra/datasets/allen/single_session_ca.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Allen single mouse dataset.
"""
Allen single mouse dataset.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
*https://github.com/zivlab/visual_drift
*http://observatory.brain-map.org/visualcoding
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
"Representational drift in the mouse visual cortex."
Current Biology 31.19 (2021): 4327-4339.
* de Vries, Saskia EJ, et al.
"A large-scale standardized physiological survey reveals functional
organization of the mouse visual cortex."
Nature Neuroscience 23.1 (2020): 138-151.
* https://github.com/zivlab/visual_drift
* http://observatory.brain-map.org/visualcoding
"""
import pathlib

Expand Down Expand Up @@ -113,7 +120,7 @@ def __getitem__(self, index):
"allen-movie1-ca-single-session-corrupt-{session_id}",
session_id=range(len(_SINGLE_SESSION_CA)),
)
class SingleSessionAllenCa(cebra.data.SingleSessionDataset):
class SingleSessionAllenCaCorrupted(cebra.data.SingleSessionDataset):
"""A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus.
A dataset of a single mouse 30Hz calcium events from the excitatory neurons in the primary visual cortex
Expand Down Expand Up @@ -352,7 +359,7 @@ def __init__(self, repeat_no, split_flag):
repeat_no=[9],
split_flag=["train", "test"],
)
class SingleSessionAllenCaDecoding(cebra.data.SingleSessionDataset):
class SingleSessionAllenCaDecodingCorrupted(cebra.data.SingleSessionDataset):
"""A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus with train/test splits.
A dataset of a single mouse 30Hz calcium events from the excitatory neurons
Expand Down
3 changes: 3 additions & 0 deletions cebra/datasets/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@

import cebra.data
import cebra.io
from cebra.datasets import get_datapath
from cebra.datasets import parametrize
from cebra.datasets import register

_DEFAULT_DATADIR = get_datapath()


@register("continuous-gaussian-mixture")
@parametrize(
Expand Down
4 changes: 2 additions & 2 deletions cebra/datasets/generate_synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import joblib as jl
import keras
import numpy as np
import poisson
import poisson as poisson_utils
import scipy.stats
import tensorflow as tf

Expand Down Expand Up @@ -228,7 +228,7 @@ def refractory_poisson(x):
flattened_lam = lam_true.flatten()
x = np.zeros_like(flattened_lam)
for i, rate in enumerate(flattened_lam):
neuron = poisson.PoissonNeuron(
neuron = poisson_utils.PoissonNeuron(
spike_rate=rate * args.scale,
num_repeats=1,
time_interval=args.time_interval,
Expand Down
19 changes: 12 additions & 7 deletions cebra/datasets/hippocampus.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,16 @@ def decode(self, x_train, y_train, x_test, y_test):
class SingleRatTrialSplitDataset(SingleRatDataset):
"""A single rat hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits.
Neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat.
The behavior label is structured as 3D array consists of position, right, and left.
The neural and behavior recordings are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
Neural data is spike counts binned into 25ms time window and the behavior is position and the running
direction (left, right) of a rat. The behavior label is structured as 3D array consists of position,
right, and left. The neural and behavior recordings are parsed into trials (a round trip from one end
of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
Args:
name: The name of a rat to use. Choose among 'achilles', 'buddy', 'cicero' and 'gatsby'.
split_no: The `k` for k-fold split. Choose among 0, 1, 2.
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split).
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'
(all trials except test split).
"""

Expand Down Expand Up @@ -281,13 +283,16 @@ class MultipleRatsTrialSplitDataset(cebra.data.DatasetCollection):
"""4 rats hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits.
Neural and behavior recordings of 4 rats.
For each rat, neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat.
For each rat, neural data is spike counts binned into 25ms time window and the behavior is position
and the running direction (left, right) of a rat.
The behavior label is structured as 3D array consists of position, right, and left.
Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track)
and the trials are split into a train, valid and test set with k=3 nested cross validation.
Args:
split_no: The `k` for k-fold split. Choose among 0, 1, and 2.
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split).
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'
(all trials except test split).
"""

Expand Down
Loading

0 comments on commit 0eac868

Please sign in to comment.