diff --git a/cebra/__init__.py b/cebra/__init__.py index b361a441..204cd2a2 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -92,11 +92,11 @@ def __getattr__(key): return CEBRA elif key == "KNNDecoder": - from cebra.integrations.sklearn.decoder import KNNDecoder + from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811 return KNNDecoder elif key == "L1LinearRegressor": - from cebra.integrations.sklearn.decoder import L1LinearRegressor + from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811 return L1LinearRegressor elif not key.startswith("_"): diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 9fa815c2..dbb2f1f5 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -22,13 +22,15 @@ """Pre-defined datasets.""" import types -from typing import List, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import numpy.typing as npt import torch import cebra.data as cebra_data +import cebra.helper as cebra_helper +from cebra.data.datatypes import Offset class TensorDataset(cebra_data.SingleSessionDataset): @@ -64,26 +66,52 @@ def __init__(self, neural: Union[torch.Tensor, npt.NDArray], continuous: Union[torch.Tensor, npt.NDArray] = None, discrete: Union[torch.Tensor, npt.NDArray] = None, - offset: int = 1, + offset: Offset = Offset(0, 1), device: str = "cpu"): super().__init__(device=device) - self.neural = self._to_tensor(neural, torch.FloatTensor).float() - self.continuous = self._to_tensor(continuous, torch.FloatTensor) - self.discrete = self._to_tensor(discrete, torch.LongTensor) + self.neural = self._to_tensor(neural, check_dtype="float").float() + self.continuous = self._to_tensor(continuous, check_dtype="float") + self.discrete = self._to_tensor(discrete, check_dtype="int") if self.continuous is None and self.discrete is None: raise ValueError( "You have to pass at least one of the arguments 'continuous' or 'discrete'." ) self.offset = offset - def _to_tensor(self, array, check_dtype=None): + def _to_tensor( + self, + array: Union[torch.Tensor, npt.NDArray], + check_dtype: Optional[Literal["int", + "float"]] = None) -> torch.Tensor: + """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype. + + Args: + array: Array to check. + check_dtype: If not `None`, list of dtypes to which the values in `array` + must belong to. Defaults to None. + + Returns: + The `array` as a :py:class:`torch.Tensor`. + """ if array is None: return None if isinstance(array, np.ndarray): array = torch.from_numpy(array) if check_dtype is not None: - if not isinstance(array, check_dtype): - raise TypeError(f"{type(array)} instead of {check_dtype}.") + if check_dtype not in ["int", "float"]: + raise ValueError( + f"check_dtype must be 'int' or 'float', got {check_dtype}") + if (check_dtype == "int" and not cebra_helper._is_integer(array) + ) or (check_dtype == "float" and + not cebra_helper._is_floating(array)): + raise TypeError( + f"Array has type {array.dtype} instead of {check_dtype}.") + if cebra_helper._is_floating(array): + array = array.float() + if cebra_helper._is_integer(array): + # NOTE(stes): Required for standardizing number format on + # windows machines. + array = array.long() return array @property diff --git a/cebra/data/helper.py b/cebra/data/helper.py index d2a1cfe3..8582edae 100644 --- a/cebra/data/helper.py +++ b/cebra/data/helper.py @@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment: For each dataset, the data and labels to align the data on is provided. - 1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``). - 2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``. - 3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets. - 4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``. + 1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to + the labels of the reference dataset (``ref_label``) are selected and used to sample + from the dataset to align (``data``). + 2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number + of samples ``subsample``. + 3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, + on those subsampled datasets. + 4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` + to the ``ref_data``. Note: ``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number diff --git a/cebra/data/load.py b/cebra/data/load.py index ddaf8ade..6f1b86e5 100644 --- a/cebra/data/load.py +++ b/cebra/data/load.py @@ -663,7 +663,8 @@ def load( - if no key is provided, the first data structure found upon iteration of the collection will be loaded; - if a key is provided, it needs to correspond to an existing item of the collection; - if a key is provided, the data value accessed needs to be a data structure; - - the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones. + - the function loads data for only one data structure, even if the file contains more. The function can be + called again with the corresponding key to get the other ones. Args: file: The path to the given file to load, in a supported format. diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 0c575ed7..ab6c9729 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -358,7 +358,6 @@ def __post_init__(self): # here might be sub-optimal. The final behavior should be determined after # e.g. integrating the FAISS dataloader back in. super().__post_init__() - index = self.index.to(self.device) if self.conditional != "time_delta": raise NotImplementedError( @@ -368,8 +367,7 @@ def __post_init__(self): self.time_distribution = cebra.distributions.TimeContrastive( time_offset=self.time_offset, num_samples=len(self.dataset.neural), - device=self.device, - ) + device=self.device) self.behavior_distribution = cebra.distributions.TimedeltaDistribution( self.dataset.continuous_index, self.time_offset, device=self.device) diff --git a/cebra/datasets/__init__.py b/cebra/datasets/__init__.py index 76bfed3c..5716e399 100644 --- a/cebra/datasets/__init__.py +++ b/cebra/datasets/__init__.py @@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str: from cebra.datasets.monkey_reaching import * from cebra.datasets.synthetic_data import * except ModuleNotFoundError as e: - import warnings - warnings.warn(f"Could not initialize one or more datasets: {e}. " f"For using the datasets, consider installing the " f"[datasets] extension via pip.") diff --git a/cebra/datasets/allen/ca_movie.py b/cebra/datasets/allen/ca_movie.py index fa25f72a..083527ee 100644 --- a/cebra/datasets/allen/ca_movie.py +++ b/cebra/datasets/allen/ca_movie.py @@ -22,11 +22,15 @@ """Allen pseudomouse Ca dataset. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ import pathlib diff --git a/cebra/datasets/allen/ca_movie_decoding.py b/cebra/datasets/allen/ca_movie_decoding.py index 8bb164cc..aefd5d57 100644 --- a/cebra/datasets/allen/ca_movie_decoding.py +++ b/cebra/datasets/allen/ca_movie_decoding.py @@ -22,11 +22,15 @@ """Allen pseudomouse Ca decoding dataset with train/test split. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ import pathlib @@ -243,11 +247,6 @@ def _convert_to_nums(string): return pseudo_mouse - pseudo_mouse = np.vstack( - [get_neural_data(num_movie, mice) for mice in list_mice]) - - return pseudo_mouse - def __len__(self): return self.neural.size(0) diff --git a/cebra/datasets/allen/combined.py b/cebra/datasets/allen/combined.py index a05eb17c..ac1208ff 100644 --- a/cebra/datasets/allen/combined.py +++ b/cebra/datasets/allen/combined.py @@ -22,13 +22,19 @@ """Joint Allen pseudomouse Ca/Neuropixel datasets. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - *https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html - *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current Biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature Neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding + * https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html + * Siegle, Joshua H., et al. + "Survey of spiking in the mouse visual system reveals functional hierarchy." + Nature 592.7852 (2021): 86-92. """ import cebra.data diff --git a/cebra/datasets/allen/make_neuropixel.py b/cebra/datasets/allen/make_neuropixel.py index 1eabfe9f..aecdf4bf 100644 --- a/cebra/datasets/allen/make_neuropixel.py +++ b/cebra/datasets/allen/make_neuropixel.py @@ -192,11 +192,12 @@ def read_neuropixel( "intervals/natural_movie_one_presentations/start_time"][...] end_time = d[ "intervals/natural_movie_one_presentations/stop_time"][...] - timeseries = d[ - "intervals/natural_movie_one_presentations/timeseries"][...] - timeseries_index = d[ - "intervals/natural_movie_one_presentations/timeseries_index"][ - ...] + # NOTE(stes): never used. leaving here for future reference + #timeseries = d[ + # "intervals/natural_movie_one_presentations/timeseries"][...] + #timeseries_index = d[ + # "intervals/natural_movie_one_presentations/timeseries_index"][ + # ...] session_no = d["identifier"][...].item() spike_time_index = d["units/spike_times_index"][...] spike_times = d["units/spike_times"][...] @@ -266,7 +267,7 @@ def read_neuropixel( "neural": sessions_dic, "frames": session_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl", ) jl.dump( @@ -274,6 +275,6 @@ def read_neuropixel( "neural": pseudo_mice, "frames": pseudo_mice_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl", ) diff --git a/cebra/datasets/allen/single_session_ca.py b/cebra/datasets/allen/single_session_ca.py index 5a3eea4d..794de602 100644 --- a/cebra/datasets/allen/single_session_ca.py +++ b/cebra/datasets/allen/single_session_ca.py @@ -19,14 +19,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Allen single mouse dataset. +""" +Allen single mouse dataset. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current Biology 31.19 (2021): 4327-4339. + + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature Neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ import pathlib @@ -113,7 +120,7 @@ def __getitem__(self, index): "allen-movie1-ca-single-session-corrupt-{session_id}", session_id=range(len(_SINGLE_SESSION_CA)), ) -class SingleSessionAllenCa(cebra.data.SingleSessionDataset): +class SingleSessionAllenCaCorrupted(cebra.data.SingleSessionDataset): """A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus. A dataset of a single mouse 30Hz calcium events from the excitatory neurons in the primary visual cortex @@ -352,7 +359,7 @@ def __init__(self, repeat_no, split_flag): repeat_no=[9], split_flag=["train", "test"], ) -class SingleSessionAllenCaDecoding(cebra.data.SingleSessionDataset): +class SingleSessionAllenCaDecodingCorrupted(cebra.data.SingleSessionDataset): """A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus with train/test splits. A dataset of a single mouse 30Hz calcium events from the excitatory neurons diff --git a/cebra/datasets/gaussian_mixture.py b/cebra/datasets/gaussian_mixture.py index 05fd971d..48e10446 100644 --- a/cebra/datasets/gaussian_mixture.py +++ b/cebra/datasets/gaussian_mixture.py @@ -27,9 +27,12 @@ import cebra.data import cebra.io +from cebra.datasets import get_datapath from cebra.datasets import parametrize from cebra.datasets import register +_DEFAULT_DATADIR = get_datapath() + @register("continuous-gaussian-mixture") @parametrize( diff --git a/cebra/datasets/generate_synthetic_data.py b/cebra/datasets/generate_synthetic_data.py index 0fc33963..a2a8048d 100644 --- a/cebra/datasets/generate_synthetic_data.py +++ b/cebra/datasets/generate_synthetic_data.py @@ -30,7 +30,7 @@ import joblib as jl import keras import numpy as np -import poisson +import poisson as poisson_utils import scipy.stats import tensorflow as tf @@ -228,7 +228,7 @@ def refractory_poisson(x): flattened_lam = lam_true.flatten() x = np.zeros_like(flattened_lam) for i, rate in enumerate(flattened_lam): - neuron = poisson.PoissonNeuron( + neuron = poisson_utils.PoissonNeuron( spike_rate=rate * args.scale, num_repeats=1, time_interval=args.time_interval, diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index 92537b8e..05c47acb 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -160,14 +160,16 @@ def decode(self, x_train, y_train, x_test, y_test): class SingleRatTrialSplitDataset(SingleRatDataset): """A single rat hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits. - Neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat. - The behavior label is structured as 3D array consists of position, right, and left. - The neural and behavior recordings are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. + Neural data is spike counts binned into 25ms time window and the behavior is position and the running + direction (left, right) of a rat. The behavior label is structured as 3D array consists of position, + right, and left. The neural and behavior recordings are parsed into trials (a round trip from one end + of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. Args: name: The name of a rat to use. Choose among 'achilles', 'buddy', 'cicero' and 'gatsby'. split_no: The `k` for k-fold split. Choose among 0, 1, 2. - split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split). + split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test' + (all trials except test split). """ @@ -281,13 +283,16 @@ class MultipleRatsTrialSplitDataset(cebra.data.DatasetCollection): """4 rats hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits. Neural and behavior recordings of 4 rats. - For each rat, neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat. + For each rat, neural data is spike counts binned into 25ms time window and the behavior is position + and the running direction (left, right) of a rat. The behavior label is structured as 3D array consists of position, right, and left. - Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. + Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) + and the trials are split into a train, valid and test set with k=3 nested cross validation. Args: split_no: The `k` for k-fold split. Choose among 0, 1, and 2. - split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split). + split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test' + (all trials except test split). """ diff --git a/cebra/datasets/make_neuropixel.py b/cebra/datasets/make_neuropixel.py index 65029f94..431191db 100644 --- a/cebra/datasets/make_neuropixel.py +++ b/cebra/datasets/make_neuropixel.py @@ -21,16 +21,19 @@ # """Generate pseudomouse Neuropixels data. -This script generates the pseudomouse Neuropixels data for each visual cortical area from original Allen ENuropixels Brain observatory 1.1 NWB data. -We followed the units filtering used in the AllenSDK package. +This script generates the pseudomouse Neuropixels data for each visual cortical area from original +Allen ENuropixels Brain observatory 1.1 NWB data. We followed the units filtering used in the AllenSDK package. References: - *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. - *https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html + * Siegle, Joshua H., et al. + "Survey of spiking in the mouse visual system reveals functional hierarchy." + Nature 592.7852 (2021): 86-92. + * https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html """ import argparse import glob +import pathlib import h5py import joblib as jl @@ -193,11 +196,12 @@ def read_neuropixel( "intervals/natural_movie_one_presentations/start_time"][...] end_time = d[ "intervals/natural_movie_one_presentations/stop_time"][...] - timeseries = d[ - "intervals/natural_movie_one_presentations/timeseries"][...] - timeseries_index = d[ - "intervals/natural_movie_one_presentations/timeseries_index"][ - ...] + # NOTE(stes): Never used. Commenting, but leaving for future ref. + #timeseries = d[ + # "intervals/natural_movie_one_presentations/timeseries"][...] + #timeseries_index = d[ + # "intervals/natural_movie_one_presentations/timeseries_index"][ + # ...] session_no = d["identifier"][...].item() spike_time_index = d["units/spike_times_index"][...] spike_times = d["units/spike_times"][...] @@ -262,13 +266,13 @@ def read_neuropixel( "neural": sessions_dic, "frames": session_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl") jl.dump( { "neural": pseudo_mice, "frames": pseudo_mice_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl", ) diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index a07e24fd..05071b12 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -22,10 +22,16 @@ """Ephys neural and behavior data used for the monkey reaching experiment. References: - * Chowdhury, Raeed H., Joshua I. Glaser, and Lee E. Miller. "Area 2 of primary somatosensory cortex encodes kinematics of the whole arm." Elife 9 (2020). - * Chowdhury, Raeed; Miller, Lee (2022) Area2 Bump: macaque somatosensory area 2 spiking activity during reaching with perturbations (Version 0.220113.0359) [Data set]. `DANDI archive `_ - * Pei, Felix, et al. "Neural Latents Benchmark'21: Evaluating latent variable models of neural population activity." arXiv preprint arXiv:2109.04463 (2021). - + * Chowdhury, Raeed H., Joshua I. Glaser, and Lee E. Miller. + "Area 2 of primary somatosensory cortex encodes kinematics of the whole arm." + Elife 9 (2020). + * Chowdhury, Raeed; Miller, Lee (2022) + Area2 Bump: macaque somatosensory area 2 spiking activity during reaching + with perturbations (Version 0.220113.0359) [Data set]. + `DANDI archive `_ + * Pei, Felix, et al. + "Neural Latents Benchmark'21: Evaluating latent variable models of neural + population activity." arXiv preprint arXiv:2109.04463 (2021). """ import pathlib @@ -421,7 +427,7 @@ def _create_area2_dataset(): for session_type in ["active", "passive", "active-passive", "all"]: @register(f"area2-bump-pos-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV1(Area2BumpDataset): """Monkey reaching dataset with hand position labels. The dataset loads continuous x,y hand position as behavior labels. @@ -450,7 +456,7 @@ def continuous_index(self): return self.pos @register(f"area2-bump-target-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV2(Area2BumpDataset): """Monkey reaching dataset with target direction labels. The dataset loads discrete target direction (0-7) as behavior labels. @@ -477,7 +483,7 @@ def continuous_index(self): return None @register(f"area2-bump-posdir-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV3(Area2BumpDataset): """Monkey reaching dataset with hand position labels and discrete target labels. The dataset loads continuous x,y hand position and discrete target labels (0-7) @@ -520,7 +526,7 @@ def _create_area2_shuffled_dataset(): for session_type in ["active", "active-passive"]: @register(f"area2-bump-pos-{session_type}-shuffled-trial") - class Dataset(Area2BumpShuffledDataset): + class DatasetV4(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled trial type. The dataset loads the discrete binary trial type label active(0)/passive(1) @@ -548,7 +554,7 @@ def continuous_index(self): return self.pos @register(f"area2-bump-pos-{session_type}-shuffled-position") - class Dataset(Area2BumpShuffledDataset): + class DatasetV5(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled hand position. The dataset loads continuous x,y hand position in randomly shuffled order. @@ -577,7 +583,7 @@ def continuous_index(self): return self.pos_shuffled @register(f"area2-bump-target-{session_type}-shuffled") - class Dataset(Area2BumpShuffledDataset): + class DatasetV6(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled hand position. The dataset loads discrete target direction (0-7 for active and 0-15 for active-passive) diff --git a/cebra/datasets/save_dataset.py b/cebra/datasets/save_dataset.py index f5e01a62..a9b3862a 100644 --- a/cebra/datasets/save_dataset.py +++ b/cebra/datasets/save_dataset.py @@ -50,7 +50,7 @@ def save_allen_decoding_dataset(savepath=get_datapath("allen_preload/")): print(f"{savepath}/{dataname}.jl") -def save_allen_dataset(savepath=get_datapth("allen_preload/")): +def save_allen_dataset(savepath=get_datapath("allen_preload/")): """Load and save complete allen dataset for Ca. Load and save all neural and behavioral data relevant for allen decoding dataset to reduce data loading time for the experiments using the shared data. It saves Ca data for the neuron numbers (10-1000), 5 different seeds for sampling the neurons. diff --git a/cebra/grid_search.py b/cebra/grid_search.py index 14337ac0..1805c896 100644 --- a/cebra/grid_search.py +++ b/cebra/grid_search.py @@ -138,7 +138,8 @@ def fit_models(self, to fit the CEBRA models on. The models are then trained using temporal contrastive learning (CEBRA-Time). An example of a valid ``datasets`` value could be: - ``datasets={"dataset1": neural_data, "dataset2": (neurald_data, continuous_data, discrete_data), "dataset3": (neural_data2, continuous_data2)}``. + ``datasets={"dataset1": neural_data, "dataset2": (neurald_data, continuous_data, discrete_data), + "dataset3": (neural_data2, continuous_data2)}``. params: Dict of parameter values provided by the user, either as a single value, for fixed hyperparameter values, or with a list of values for hyperparameters to optimize. If the value is a list of a single element, the hyperparameter is considered as fixed. diff --git a/cebra/helper.py b/cebra/helper.py index 2175e6ac..93bae2b1 100644 --- a/cebra/helper.py +++ b/cebra/helper.py @@ -99,7 +99,7 @@ def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool: def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool: - """Check if the values in ``y`` are :py:class:`int`. + """Check if the values in ``y`` are :py:class:`float`. Note: There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor` @@ -118,6 +118,19 @@ def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool: y, torch.Tensor) and torch.is_floating_point(y)) +def _is_floating_or_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool: + """Check if the values in ``y`` are :py:class:`int` or :py:class:`float`. + + Args: + y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`. + + Returns: + ``True`` if ``y`` contains :py:class:`float` or :py:class:`int`. + """ + + return _is_floating(y) or _is_integer(y) + + def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]: """Return all possible dataloaders for the given dataset. @@ -152,7 +165,7 @@ def _requires_package_version(function): @wraps(function) def wrapper(*args, patched_version=None, **kwargs): - if patched_version != None: + if patched_version is not None: installed_version = pkg_resources.parse_version( patched_version) # Use the patched version if provided else: diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py index b79deb66..30af7fd4 100644 --- a/cebra/integrations/matplotlib.py +++ b/cebra/integrations/matplotlib.py @@ -35,6 +35,7 @@ import torch from cebra import CEBRA +from cebra.helper import requires_package_version def _register_colormap(): @@ -289,13 +290,13 @@ def _define_plot_dim( * If ``idx_order`` is not provided, the plot will be 3D by default. * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if only 2 dimensions - are provided, the plot will be 2D. + are provided, the plot will be 2D. If the embedding dimension is equal to 2: * If ``idx_order`` is not provided, the plot will be 2D by default. * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if 2 dimensions - are provided, the plot will be 2D. + are provided, the plot will be 2D. This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). @@ -480,8 +481,9 @@ def plot(self, **kwargs) -> matplotlib.axes.Axes: elif isinstance(self.embedding_labels, Iterable): if len(self.embedding_labels) != self.embedding.shape[0]: raise ValueError( - f"Invalid embedding labels: the labels vector should have the same number of samples as the embedding, got {len(self.embedding_labels)}, expect {self.embedding.shape[0]}." - ) + f"Invalid embedding labels: the labels vector should have the same number " + f"of samples as the embedding, got {len(self.embedding_labels)}, " + f"expected {self.embedding.shape[0]}.") if self.embedding_labels.ndim > 1: raise NotImplementedError( f"Invalid embedding labels: plotting does not support multiple sets of labels, got {self.embedding_labels.ndim}." @@ -668,8 +670,9 @@ def _to_heatmap_format( assert len(pairs) == len(values), (self.pairs.shape, len(values)) score_dict = {tuple(pair): value for pair, value in zip(pairs, values)} - if self.labels is None: - n_grid = self.score + # NOTE(stes): Never used, might be possible to remove. + #if self.labels is None: + # n_grid = self.score heatmap_values = np.zeros((len(self.labels), len(self.labels))) @@ -1012,46 +1015,53 @@ def plot_embedding( If the embedding dimension is equal or higher to 3: - * If ``idx_order`` is not provided, the plot will be 3D by default. - * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if only 2 dimensions are provided, the plot will be 2D. + - If ``idx_order`` is not provided, the plot will be 3D by default. + - If ``idx_order`` is provided, and it has 3 dimensions, the plot will be 3D; + if only 2 dimensions are provided, the plot will be 2D. If the embedding dimension is equal to 2: - * If ``idx_order`` is not provided, the plot will be 2D by default. - * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if 2 dimensions are provided, the plot will be 2D. + - If ``idx_order`` is not provided, the plot will be 2D by default. + - If ``idx_order`` is provided, and it has 3 dimensions, the plot will be 3D; + if 2 dimensions are provided, the plot will be 2D. - This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of - dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). - - The function makes use of :py:func:`matplotlib.pyplot.scatter` and parameters from that function can be provided - as part of ``kwargs``. + This assumes that the dimensions provided to ``idx_order`` are within the range of the + number of dimensions of the embedding (i.e., between 0 and + :py:attr:`cebra.CEBRA.output_dimension` -1). + The function makes use of :py:func:`matplotlib.pyplot.scatter`, and parameters from + that function can be provided as part of ``kwargs``. Args: embedding: A matrix containing the feature representation computed with CEBRA. embedding_labels: The labels used to map the data to color. It can be: - * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. - * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. + - A vector that is the same sample size as the embedding, associating a value + to each sample, either discrete or continuous. + - A string, either `time`, which will color the embedding based on temporality, + or a string that can be interpreted as an RGB(A) color, which will display + the embedding uniformly with that color. + ax: Optional axis to create the plot on. - idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D - embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those - dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation - (e.g., (2, 4, 5)). + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension + in the 3D/2D embedding. The simplest form is (0, 1, 2) or (0, 1), but one might + want to plot either those dimensions differently (e.g., (1, 0, 2)) or other + dimensions from the feature representation (e.g., (2, 4, 5)). + markersize: The marker size. alpha: The marker blending, between 0 (transparent) and 1 (opaque). - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). title: The title on top of the embedding. figsize: Figure width and height in inches. dpi: Figure resolution. - kwargs: Optional arguments to customize the plots. See :py:func:`matplotlib.pyplot.scatter` documentation for more - details on which arguments to use. + kwargs: Optional arguments to customize the plots. See :py:func:`matplotlib.pyplot.scatter` + documentation for more details on which arguments to use. Returns: The axis :py:meth:`matplotlib.axes.Axes.axis` of the plot. Example: - >>> import cebra >>> import numpy as np >>> X = np.random.uniform(0, 1, (100, 50)) @@ -1061,8 +1071,8 @@ def plot_embedding( CEBRA(max_iterations=10) >>> embedding = cebra_model.transform(X) >>> ax = cebra.plot_embedding(embedding, embedding_labels='time') - """ + return _EmbeddingPlot( embedding=embedding, embedding_labels=embedding_labels, @@ -1134,7 +1144,12 @@ def plot_consistency( >>> labels2 = np.random.uniform(0, 1, (1000, )) >>> dataset_ids = ["achilles", "buddy"] >>> # between-datasets consistency, by aligning on the labels - >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2], labels=[labels1, labels2], dataset_ids=dataset_ids, between="datasets") + >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score( + ... embeddings=[embedding1, embedding2], + ... labels=[labels1, labels2], + ... dataset_ids=dataset_ids, + ... between="datasets" + ... ) >>> ax = cebra.plot_consistency(scores, pairs, datasets, vmin=0, vmax=100) """ @@ -1153,9 +1168,6 @@ def plot_consistency( ).plot(**kwargs) -from cebra.helper import requires_package_version - - @requires_package_version(matplotlib, "3.6") def compare_models( models: List[CEBRA], diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index e60bcb8f..bbaa1de6 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -79,7 +79,8 @@ def _define_colorscale(self, cmap: str): """Specify the cmap for plotting the latent space. Args: - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). Returns: @@ -171,39 +172,41 @@ def plot_embedding_interactive( The function makes use of :py:class:`plotly.graph_objects.Scatter` and parameters from that function can be provided as part of ``kwargs``. - Args: embedding: A matrix containing the feature representation computed with CEBRA. embedding_labels: The labels used to map the data to color. It can be: - * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. - * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. + - A vector that is the same sample size as the embedding, associating a value + to each of the sample, either discrete or continuous. + - A string, either `time`, then the labels will color the embedding based on + temporality, or a string that can be interpreted as a RGB(A) color, then + the embedding will be uniformly displayed with that unique color. + axis: Optional axis to create the plot on. - idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D - embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those - dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation - (e.g., (2, 4, 5)). + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension + in the 3D/2D embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want + to plot either those dimensions differently (e.g., (1, 0, 2)) or other dimensions + from the feature representation (e.g., (2, 4, 5)). + markersize: The marker size. alpha: The marker blending, between 0 (transparent) and 1 (opaque). - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). title: The title on top of the embedding. figsize: Figure width and height in inches. dpi: Figure resolution. - kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments: - -- showlegend: Whether to show the legend or not. - -- discrete: Whether the labels are discrete or not. - -- col: The column of the subplot to plot the embedding on. - -- row: The row of the subplot to plot the embedding on. - -- template: The template to use for the plot. - - Note: showlegend can be True only if discrete is True. - - See :py:class:`plotly.graph_objects.Scatter` documentation for more - details on which arguments to use. - - Returns: - The plotly figure. - + kwargs: Optional arguments to customize the plots. This dictionary includes the following + optional arguments: + + - ``showlegend``: Whether to show the legend or not. + - ``discrete``: Whether the labels are discrete or not. + - ``col``: The column of the subplot to plot the embedding on. + - ``row``: The row of the subplot to plot the embedding on. + - ``template``: The template to use for the plot. + + Note: ``showlegend`` can be ``True`` only if ``discrete`` is ``True``. + See :py:class:`plotly.graph_objects.Scatter` documentation for more details on which + arguments to use. Example: diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 97beaaaa..a340a392 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -314,7 +314,8 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": for key, value in state.items(): setattr(cebra_, key, value) - state_and_args = {**args, **state} + #TODO(stes): unused right now + #state_and_args = {**args, **state} if not sklearn_utils.check_fitted(cebra_): raise ValueError( @@ -1358,13 +1359,13 @@ def save(self, - 'args': A dictionary of parameters used to initialize the CEBRA model. - 'state': The state of the CEBRA model, which includes various internal attributes. - 'state_dict': The state dictionary of the underlying solver used by CEBRA. - - 'metadata': Additional metadata about the saved model, including the backend used and the version of CEBRA PyTorch, NumPy and scikit-learn. + - 'metadata': Additional metadata about the saved model, including the backend used and + the version of CEBRA PyTorch, NumPy and scikit-learn. "torch" backend: The model is directly saved using `torch.save` with no additional information. The saved file contains the entire CEBRA model state. - Example: >>> import cebra diff --git a/cebra/integrations/sklearn/helpers.py b/cebra/integrations/sklearn/helpers.py index 9127aaa2..2d2fc627 100644 --- a/cebra/integrations/sklearn/helpers.py +++ b/cebra/integrations/sklearn/helpers.py @@ -40,7 +40,7 @@ def _get_min_max( min = float("inf") max = float("-inf") for label in labels: - if any(isinstance(l, str) for l in label): + if any(isinstance(label_element, str) for label_element in label): raise ValueError( "Invalid labels dtype, expect floats or integers, got string") min = np.min(label) if min > np.min(label) else min diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index d2a5a04f..47c2a87f 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -36,6 +36,7 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F from torch import nn @@ -212,7 +213,6 @@ def __init__(self, self.max_inverse_temperature = math.inf else: self.max_inverse_temperature = 1.0 / min_temperature - start_tempearture = float(temperature) log_inverse_temperature = torch.tensor( math.log(1.0 / float(temperature))) self.log_inverse_temperature = nn.Parameter(log_inverse_temperature) diff --git a/cebra/models/multiobjective.py b/cebra/models/multiobjective.py index da7f992e..d9393fdc 100644 --- a/cebra/models/multiobjective.py +++ b/cebra/models/multiobjective.py @@ -94,7 +94,7 @@ def is_valid(self, mode): Returns: ``True`` for a valid representation, ``False`` otherwise. """ - return mode in _ALL + return mode in _ALL # noqa: F821 def __init__( self, diff --git a/cebra/registry.py b/cebra/registry.py index be9afbd0..994fbd5c 100644 --- a/cebra/registry.py +++ b/cebra/registry.py @@ -115,7 +115,9 @@ def get_options( ): instance = cls.get_instance(module) if expand_parametrized: - filter_ = lambda k, v: True + + def filter_(k, v): + return True else: class _Filter(set): diff --git a/cebra/solver/supervised.py b/cebra/solver/supervised.py index f4e4f95c..54a2da3a 100644 --- a/cebra/solver/supervised.py +++ b/cebra/solver/supervised.py @@ -61,7 +61,7 @@ def fit(self, step_idx = 0 while True: for _, batch in enumerate(loader): - stats = self.step(batch) + _ = self.step(batch) self._log_checkpoint(num_steps, loader, valid_loader) step_idx += 1 if step_idx >= num_steps: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c9f9fb2f..6a7f9319 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -70,6 +70,7 @@ def test_demo(): def test_hippocampus(): pytest.skip("Outdated") + dataset = cebra.datasets.init("rat-hippocampus-single") loader = cebra.data.ContinuousDataLoader( dataset=dataset, @@ -145,7 +146,7 @@ def test_allen(): multisubject_options.extend( cebra.datasets.get_options( "rat-hippocampus-multisubjects-3fold-trial-split*")) -except: +except: # noqa: E722 options = [] @@ -385,3 +386,106 @@ def test_download_file_wrong_content_disposition(filename, url, expected_checksum=expected_checksum, location=temp_dir, file_name=filename) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_initialization(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + assert dataset.neural.shape == neural.shape + if continuous is not None: + assert dataset.continuous.shape == continuous.shape + if discrete is not None: + assert dataset.discrete.shape == discrete.shape + + +def test_tensor_dataset_invalid_initialization(): + neural = np.random.randn(100, 30) + with pytest.raises(ValueError): + cebra.data.datasets.TensorDataset(neural) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_length(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + assert len(dataset) == len(neural) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_getitem(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + index = torch.randint(0, len(dataset), (10,)) + batch = dataset[index] + assert batch.shape[0] == len(index) + assert batch.shape[1] == neural.shape[1] + + +def test_tensor_dataset_invalid_discrete_type(): + neural = np.random.randn(100, 30) + continuous = np.random.randn(100, 2) + discrete = np.random.randn(100, 2) # Invalid type: float instead of int + with pytest.raises(TypeError): + cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + + +@pytest.mark.parametrize("array, check_dtype, expected_dtype", [ + (np.random.randn(100, 30), "float", torch.float32), + (np.random.randint(0, 5, (100, 30)), "int", torch.int64), + (torch.randn(100, 30), "float", torch.float32), + (torch.randint(0, 5, (100, 30)), "int", torch.int64), + (None, None, None), +]) +def test_to_tensor(array, check_dtype, expected_dtype): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + result = dataset._to_tensor(array, check_dtype=check_dtype) + if array is None: + assert result is None + else: + assert isinstance(result, torch.Tensor) + assert result.dtype == expected_dtype + + +def test_to_tensor_invalid_dtype(): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + array = np.random.randn(100, 30) + with pytest.raises(TypeError): + dataset._to_tensor(array, check_dtype="int") + array = np.random.randint(0, 5, (100, 30)) + with pytest.raises(TypeError): + dataset._to_tensor(array, check_dtype="float") + + +def test_to_tensor_invalid_check_dtype(): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + array = np.random.randn(100, 30) + with pytest.raises(ValueError, + match="check_dtype must be 'int' or 'float', got"): + dataset._to_tensor(array, check_dtype="invalid_dtype")