From 0e2312a81f42282ef321c8ddd2ecac30ec9a33cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Sat, 26 Oct 2024 15:54:35 +0200 Subject: [PATCH 1/4] Extend type checking to all float datatypes (#166) * Make the type checking less sensitive * Use existing type checking functions * Add better typing * Update cebra/data/datasets.py * Update cebra/data/datasets.py --------- Co-authored-by: Steffen Schneider Co-authored-by: Mackenzie Mathis --- cebra/data/datasets.py | 31 ++++++++++++++++++++++++------- cebra/helper.py | 15 ++++++++++++++- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 0b7f191d..b3c7015a 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -24,7 +24,7 @@ import abc import collections import types -from typing import List, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import literate_dataclasses as dataclasses import numpy as np @@ -74,23 +74,40 @@ def __init__(self, offset: int = 1, device: str = "cpu"): super().__init__(device=device) - self.neural = self._to_tensor(neural, torch.FloatTensor).float() - self.continuous = self._to_tensor(continuous, torch.FloatTensor) - self.discrete = self._to_tensor(discrete, torch.LongTensor) + self.neural = self._to_tensor(neural, check_dtype="float").float() + self.continuous = self._to_tensor(continuous, + check_dtype="float") + self.discrete = self._to_tensor(discrete, check_dtype="integer") if self.continuous is None and self.discrete is None: raise ValueError( "You have to pass at least one of the arguments 'continuous' or 'discrete'." ) self.offset = offset - def _to_tensor(self, array, check_dtype=None): + def _to_tensor( + self, + array: Union[torch.Tensor, npt.NDArray], + check_dtype: Optional[Literal["int", + "float"]] = None) -> torch.Tensor: + """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype. + + Args: + array: Array to check. + check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array` + must belong to. Defaults to None. + + Returns: + The `array` as a :py:class:`torch.Tensor`. + """ if array is None: return None if isinstance(array, np.ndarray): array = torch.from_numpy(array) if check_dtype is not None: - if not isinstance(array, check_dtype): - raise TypeError(f"{type(array)} instead of {check_dtype}.") + if (check_dtype == "int" and not cebra.helper._is_integer(array) + ) or (check_dtype == "float" and + not cebra.helper._is_floating(array)): + raise TypeError(f"Array has type {array.dtype} instead of {check_dtype}.") return array @property diff --git a/cebra/helper.py b/cebra/helper.py index 2175e6ac..8e9557e5 100644 --- a/cebra/helper.py +++ b/cebra/helper.py @@ -99,7 +99,7 @@ def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool: def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool: - """Check if the values in ``y`` are :py:class:`int`. + """Check if the values in ``y`` are :py:class:`float`. Note: There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor` @@ -118,6 +118,19 @@ def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool: y, torch.Tensor) and torch.is_floating_point(y)) +def _is_floating_or_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool: + """Check if the values in ``y`` are :py:class:`int` or :py:class:`float`. + + Args: + y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`. + + Returns: + ``True`` if ``y`` contains :py:class:`float` or :py:class:`int`. + """ + + return _is_floating(y) or _is_integer(y) + + def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]: """Return all possible dataloaders for the given dataset. From 60080522b0538fd76e1496061c31f8cc95a9f5b9 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 26 Oct 2024 23:52:21 +0200 Subject: [PATCH 2/4] Fix linting issues (#185) * Apply auto-fixes * Fix issues in allen datasets - lines too long - unused variables - missing imports - duplicate names for classes * Fix line length+minor issues in plot integr. * Fix linting issues in sklearn integration * Fix linting issues in datasets - missing paths - long lines - unused variables - typos * Fix minor linting issues * Fix docstrings * fix formatting issue in docstring * Fix plotly docstrings * Fix missing import --- cebra/__init__.py | 10 +-- cebra/__main__.py | 4 -- cebra/config.py | 1 - cebra/data/base.py | 3 - cebra/data/datasets.py | 18 ++--- cebra/data/datatypes.py | 3 - cebra/data/helper.py | 23 +++--- cebra/data/load.py | 3 +- cebra/data/multi_session.py | 2 - cebra/data/single_session.py | 11 +-- cebra/datasets/__init__.py | 2 - cebra/datasets/allen/ca_movie.py | 18 ++--- cebra/datasets/allen/ca_movie_decoding.py | 24 +++---- cebra/datasets/allen/combined.py | 40 +++++------ cebra/datasets/allen/make_neuropixel.py | 17 +++-- cebra/datasets/allen/neuropixel_movie.py | 14 +--- .../allen/neuropixel_movie_decoding.py | 8 --- cebra/datasets/allen/single_session_ca.py | 29 ++++---- cebra/datasets/gaussian_mixture.py | 5 +- cebra/datasets/generate_synthetic_data.py | 5 +- cebra/datasets/hippocampus.py | 21 +++--- cebra/datasets/make_neuropixel.py | 27 +++---- cebra/datasets/monkey_reaching.py | 31 ++++---- cebra/datasets/save_dataset.py | 2 +- cebra/distributions/base.py | 3 +- cebra/distributions/continuous.py | 5 +- cebra/distributions/index.py | 7 +- cebra/distributions/mixed.py | 1 - cebra/grid_search.py | 3 +- cebra/helper.py | 2 +- cebra/integrations/deeplabcut.py | 2 +- cebra/integrations/matplotlib.py | 72 +++++++++++-------- cebra/integrations/plotly.py | 51 ++++++------- cebra/integrations/sklearn/cebra.py | 38 +++++----- cebra/integrations/sklearn/helpers.py | 4 +- cebra/integrations/sklearn/metrics.py | 12 ++-- cebra/models/criterions.py | 4 +- cebra/models/model.py | 2 - cebra/models/multiobjective.py | 2 +- cebra/models/projector.py | 2 +- cebra/registry.py | 4 +- cebra/solver/base.py | 5 +- cebra/solver/multi_session.py | 3 - cebra/solver/single_session.py | 3 - cebra/solver/supervised.py | 10 +-- 45 files changed, 251 insertions(+), 305 deletions(-) diff --git a/cebra/__init__.py b/cebra/__init__.py index fd4cf58c..204cd2a2 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -33,7 +33,7 @@ from cebra.integrations.sklearn.decoder import L1LinearRegressor is_sklearn_available = True -except ImportError as e: +except ImportError: # silently fail for now pass @@ -42,7 +42,7 @@ from cebra.integrations.matplotlib import * is_matplotlib_available = True -except ImportError as e: +except ImportError: # silently fail for now pass @@ -51,7 +51,7 @@ from cebra.integrations.plotly import * is_plotly_available = True -except ImportError as e: +except ImportError: # silently fail for now pass @@ -92,11 +92,11 @@ def __getattr__(key): return CEBRA elif key == "KNNDecoder": - from cebra.integrations.sklearn.decoder import KNNDecoder + from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811 return KNNDecoder elif key == "L1LinearRegressor": - from cebra.integrations.sklearn.decoder import L1LinearRegressor + from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811 return L1LinearRegressor elif not key.startswith("_"): diff --git a/cebra/__main__.py b/cebra/__main__.py index 6c7c18bf..4ba66993 100644 --- a/cebra/__main__.py +++ b/cebra/__main__.py @@ -27,11 +27,7 @@ import argparse import sys -import numpy as np -import torch - import cebra -import cebra.distributions as cebra_distr def train(parser, kwargs): diff --git a/cebra/config.py b/cebra/config.py index ba6e3922..a960721f 100644 --- a/cebra/config.py +++ b/cebra/config.py @@ -21,7 +21,6 @@ # import argparse import json -from dataclasses import MISSING from typing import Literal, Optional import literate_dataclasses as dataclasses diff --git a/cebra/data/base.py b/cebra/data/base.py index d2ee47b5..4fa7ba6c 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -22,11 +22,8 @@ """Base classes for datasets and loaders.""" import abc -import collections -from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data.assets as cebra_data_assets diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index b3c7015a..ecfc31ee 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -21,21 +21,15 @@ # """Pre-defined datasets.""" -import abc -import collections import types from typing import List, Literal, Optional, Tuple, Union -import literate_dataclasses as dataclasses import numpy as np import numpy.typing as npt import torch -from numpy.typing import NDArray import cebra.data as cebra_data -import cebra.distributions -from cebra.data.datatypes import Batch -from cebra.data.datatypes import BatchIndex +import cebra.helper as cebra_helper class TensorDataset(cebra_data.SingleSessionDataset): @@ -75,8 +69,7 @@ def __init__(self, device: str = "cpu"): super().__init__(device=device) self.neural = self._to_tensor(neural, check_dtype="float").float() - self.continuous = self._to_tensor(continuous, - check_dtype="float") + self.continuous = self._to_tensor(continuous, check_dtype="float") self.discrete = self._to_tensor(discrete, check_dtype="integer") if self.continuous is None and self.discrete is None: raise ValueError( @@ -104,10 +97,11 @@ def _to_tensor( if isinstance(array, np.ndarray): array = torch.from_numpy(array) if check_dtype is not None: - if (check_dtype == "int" and not cebra.helper._is_integer(array) + if (check_dtype == "int" and not cebra_helper._is_integer(array) ) or (check_dtype == "float" and - not cebra.helper._is_floating(array)): - raise TypeError(f"Array has type {array.dtype} instead of {check_dtype}.") + not cebra_helper._is_floating(array)): + raise TypeError( + f"Array has type {array.dtype} instead of {check_dtype}.") return array @property diff --git a/cebra/data/datatypes.py b/cebra/data/datatypes.py index 11583909..4b2ac8a2 100644 --- a/cebra/data/datatypes.py +++ b/cebra/data/datatypes.py @@ -20,9 +20,6 @@ # limitations under the License. # import collections -from typing import Tuple - -import torch __all__ = ["Batch", "BatchIndex", "Offset"] diff --git a/cebra/data/helper.py b/cebra/data/helper.py index c324a80f..8582edae 100644 --- a/cebra/data/helper.py +++ b/cebra/data/helper.py @@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment: For each dataset, the data and labels to align the data on is provided. - 1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``). - 2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``. - 3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets. - 4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``. + 1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to + the labels of the reference dataset (``ref_label``) are selected and used to sample + from the dataset to align (``data``). + 2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number + of samples ``subsample``. + 3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, + on those subsampled datasets. + 4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` + to the ``ref_data``. Note: ``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number @@ -181,14 +186,14 @@ def fit( elif ref_data.shape[0] == data.shape[0] and (ref_label is None or label is None): raise ValueError( - f"Missing labels: the data to align are the same shape but you provided only " - f"one of the sets of labels. Either provide both the reference and alignment " - f"labels or none.") + "Missing labels: the data to align are the same shape but you provided only " + "one of the sets of labels. Either provide both the reference and alignment " + "labels or none.") else: if ref_label is None or label is None: raise ValueError( - f"Missing labels: the data to align are not the same shape, " - f"provide labels to align the data and reference data.") + "Missing labels: the data to align are not the same shape, " + "provide labels to align the data and reference data.") if len(ref_label.shape) == 1: ref_label = np.expand_dims(ref_label, axis=1) diff --git a/cebra/data/load.py b/cebra/data/load.py index ddaf8ade..6f1b86e5 100644 --- a/cebra/data/load.py +++ b/cebra/data/load.py @@ -663,7 +663,8 @@ def load( - if no key is provided, the first data structure found upon iteration of the collection will be loaded; - if a key is provided, it needs to correspond to an existing item of the collection; - if a key is provided, the data value accessed needs to be a data structure; - - the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones. + - the function loads data for only one data structure, even if the file contains more. The function can be + called again with the corresponding key to get the other ones. Args: file: The path to the given file to load, in a supported format. diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 7bf225a0..ddcc0fa8 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -22,11 +22,9 @@ """Datasets and loaders for multi-session training.""" import abc -import collections from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data as cebra_data diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index c27b10f5..7802b787 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -26,12 +26,9 @@ """ import abc -import collections import warnings -from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data as cebra_data @@ -353,18 +350,16 @@ def __post_init__(self): # here might be sub-optimal. The final behavior should be determined after # e.g. integrating the FAISS dataloader back in. super().__post_init__() - index = self.index.to(self.device) if self.conditional != "time_delta": raise NotImplementedError( - f"Hybrid training is currently only implemented using the ``time_delta`` " - f"continual distribution.") + "Hybrid training is currently only implemented using the ``time_delta`` " + "continual distribution.") self.time_distribution = cebra.distributions.TimeContrastive( time_offset=self.time_offset, num_samples=len(self.dataset.neural), - device=self.device, - ) + device=self.device) self.behavior_distribution = cebra.distributions.TimedeltaDistribution( self.dataset.continuous_index, self.time_offset, device=self.device) diff --git a/cebra/datasets/__init__.py b/cebra/datasets/__init__.py index 76bfed3c..5716e399 100644 --- a/cebra/datasets/__init__.py +++ b/cebra/datasets/__init__.py @@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str: from cebra.datasets.monkey_reaching import * from cebra.datasets.synthetic_data import * except ModuleNotFoundError as e: - import warnings - warnings.warn(f"Could not initialize one or more datasets: {e}. " f"For using the datasets, consider installing the " f"[datasets] extension via pip.") diff --git a/cebra/datasets/allen/ca_movie.py b/cebra/datasets/allen/ca_movie.py index f11e5e93..083527ee 100644 --- a/cebra/datasets/allen/ca_movie.py +++ b/cebra/datasets/allen/ca_movie.py @@ -22,18 +22,19 @@ """Allen pseudomouse Ca dataset. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np import pandas as pd @@ -46,7 +47,6 @@ import cebra.data from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/ca_movie_decoding.py b/cebra/datasets/allen/ca_movie_decoding.py index 12d6cc64..aefd5d57 100644 --- a/cebra/datasets/allen/ca_movie_decoding.py +++ b/cebra/datasets/allen/ca_movie_decoding.py @@ -22,18 +22,19 @@ """Allen pseudomouse Ca decoding dataset with train/test split. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np import pandas as pd @@ -41,12 +42,10 @@ import torch from numpy.random import Generator from numpy.random import PCG64 -from sklearn.decomposition import PCA import cebra.data from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS from cebra.datasets.allen import SEEDS_DISJOINT @@ -248,11 +247,6 @@ def _convert_to_nums(string): return pseudo_mouse - pseudo_mouse = np.vstack( - [get_neural_data(num_movie, mice) for mice in list_mice]) - - return pseudo_mouse - def __len__(self): return self.neural.size(0) diff --git a/cebra/datasets/allen/combined.py b/cebra/datasets/allen/combined.py index bfaca9b3..ac1208ff 100644 --- a/cebra/datasets/allen/combined.py +++ b/cebra/datasets/allen/combined.py @@ -22,31 +22,23 @@ """Joint Allen pseudomouse Ca/Neuropixel datasets. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding - *https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html - *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. - + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current Biology 31.19 (2021): 4327-4339. + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature Neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding + * https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html + * Siegle, Joshua H., et al. + "Survey of spiking in the mouse visual system reveals functional hierarchy." + Nature 592.7852 (2021): 86-92. """ -import glob -import hashlib - -import h5py -import joblib -import numpy as np -import pandas as pd -import scipy.io -import torch -from numpy.random import Generator -from numpy.random import PCG64 -from sklearn.decomposition import PCA - import cebra.data from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import ca_movie from cebra.datasets.allen import ca_movie_decoding from cebra.datasets.allen import neuropixel_movie @@ -80,7 +72,7 @@ def __init__(self, num_neurons=1000, seed=111, area="VISp"): ) def __repr__(self): - return f"CaNeuropixelDataset" + return "CaNeuropixelDataset" @parametrize( @@ -117,7 +109,7 @@ def __init__(self, ) def __repr__(self): - return f"CaNeuropixelMovieOneCorticesDataset" + return "CaNeuropixelMovieOneCorticesDataset" @parametrize( @@ -152,4 +144,4 @@ def __init__(self, group, num_neurons, seed, cortex, split_flag="train"): ) def __repr__(self): - return f"CaNeuropixelMovieOneCorticesDisjointDataset" + return "CaNeuropixelMovieOneCorticesDisjointDataset" diff --git a/cebra/datasets/allen/make_neuropixel.py b/cebra/datasets/allen/make_neuropixel.py index 5c0568b7..aecdf4bf 100644 --- a/cebra/datasets/allen/make_neuropixel.py +++ b/cebra/datasets/allen/make_neuropixel.py @@ -31,14 +31,12 @@ """ import argparse -import glob import pathlib import h5py import joblib as jl import numpy as np import numpy.typing as npt -import pandas as pd from cebra.datasets import get_datapath @@ -194,11 +192,12 @@ def read_neuropixel( "intervals/natural_movie_one_presentations/start_time"][...] end_time = d[ "intervals/natural_movie_one_presentations/stop_time"][...] - timeseries = d[ - "intervals/natural_movie_one_presentations/timeseries"][...] - timeseries_index = d[ - "intervals/natural_movie_one_presentations/timeseries_index"][ - ...] + # NOTE(stes): never used. leaving here for future reference + #timeseries = d[ + # "intervals/natural_movie_one_presentations/timeseries"][...] + #timeseries_index = d[ + # "intervals/natural_movie_one_presentations/timeseries_index"][ + # ...] session_no = d["identifier"][...].item() spike_time_index = d["units/spike_times_index"][...] spike_times = d["units/spike_times"][...] @@ -268,7 +267,7 @@ def read_neuropixel( "neural": sessions_dic, "frames": session_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl", ) jl.dump( @@ -276,6 +275,6 @@ def read_neuropixel( "neural": pseudo_mice, "frames": pseudo_mice_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl", ) diff --git a/cebra/datasets/allen/neuropixel_movie.py b/cebra/datasets/allen/neuropixel_movie.py index 51011407..f9b9c3ea 100644 --- a/cebra/datasets/allen/neuropixel_movie.py +++ b/cebra/datasets/allen/neuropixel_movie.py @@ -26,24 +26,12 @@ *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. """ -import glob -import hashlib import pathlib -import h5py import joblib -import numpy as np -import pandas as pd -import scipy.io -import torch -from numpy.random import Generator -from numpy.random import PCG64 -from sklearn.decomposition import PCA - -import cebra.data + from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import ca_movie from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/neuropixel_movie_decoding.py b/cebra/datasets/allen/neuropixel_movie_decoding.py index a99f367d..4ff1ebc2 100644 --- a/cebra/datasets/allen/neuropixel_movie_decoding.py +++ b/cebra/datasets/allen/neuropixel_movie_decoding.py @@ -26,25 +26,17 @@ *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. """ -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np -import pandas as pd -import scipy.io import torch from numpy.random import Generator from numpy.random import PCG64 -from sklearn.decomposition import PCA import cebra.data -from cebra.datasets import allen from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import ca_movie_decoding from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/single_session_ca.py b/cebra/datasets/allen/single_session_ca.py index f207a1bc..794de602 100644 --- a/cebra/datasets/allen/single_session_ca.py +++ b/cebra/datasets/allen/single_session_ca.py @@ -19,34 +19,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Allen single mouse dataset. +""" +Allen single mouse dataset. References: - *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339. - *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151. - *https://github.com/zivlab/visual_drift - *http://observatory.brain-map.org/visualcoding + * Deitch, Daniel, Alon Rubin, and Yaniv Ziv. + "Representational drift in the mouse visual cortex." + Current Biology 31.19 (2021): 4327-4339. + + * de Vries, Saskia EJ, et al. + "A large-scale standardized physiological survey reveals functional + organization of the mouse visual cortex." + Nature Neuroscience 23.1 (2020): 138-151. + * https://github.com/zivlab/visual_drift + * http://observatory.brain-map.org/visualcoding """ -import glob -import hashlib import pathlib -import h5py -import joblib import numpy as np -import pandas as pd import scipy.io import torch -from numpy.random import Generator -from numpy.random import PCG64 from sklearn.decomposition import PCA import cebra.data from cebra.datasets import get_datapath from cebra.datasets import init from cebra.datasets import parametrize -from cebra.datasets import register _DEFAULT_DATADIR = get_datapath() @@ -121,7 +120,7 @@ def __getitem__(self, index): "allen-movie1-ca-single-session-corrupt-{session_id}", session_id=range(len(_SINGLE_SESSION_CA)), ) -class SingleSessionAllenCa(cebra.data.SingleSessionDataset): +class SingleSessionAllenCaCorrupted(cebra.data.SingleSessionDataset): """A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus. A dataset of a single mouse 30Hz calcium events from the excitatory neurons in the primary visual cortex @@ -360,7 +359,7 @@ def __init__(self, repeat_no, split_flag): repeat_no=[9], split_flag=["train", "test"], ) -class SingleSessionAllenCaDecoding(cebra.data.SingleSessionDataset): +class SingleSessionAllenCaDecodingCorrupted(cebra.data.SingleSessionDataset): """A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus with train/test splits. A dataset of a single mouse 30Hz calcium events from the excitatory neurons diff --git a/cebra/datasets/gaussian_mixture.py b/cebra/datasets/gaussian_mixture.py index f5508838..48e10446 100644 --- a/cebra/datasets/gaussian_mixture.py +++ b/cebra/datasets/gaussian_mixture.py @@ -20,12 +20,9 @@ # limitations under the License. # import pathlib -from typing import Tuple import joblib as jl -import literate_dataclasses as dataclasses import numpy as np -import sklearn import torch import cebra.data @@ -34,6 +31,8 @@ from cebra.datasets import parametrize from cebra.datasets import register +_DEFAULT_DATADIR = get_datapath() + @register("continuous-gaussian-mixture") @parametrize( diff --git a/cebra/datasets/generate_synthetic_data.py b/cebra/datasets/generate_synthetic_data.py index 8a243d6d..a2a8048d 100644 --- a/cebra/datasets/generate_synthetic_data.py +++ b/cebra/datasets/generate_synthetic_data.py @@ -26,12 +26,11 @@ """ import argparse import pathlib -import sys import joblib as jl import keras import numpy as np -import poisson +import poisson as poisson_utils import scipy.stats import tensorflow as tf @@ -229,7 +228,7 @@ def refractory_poisson(x): flattened_lam = lam_true.flatten() x = np.zeros_like(flattened_lam) for i, rate in enumerate(flattened_lam): - neuron = poisson.PoissonNeuron( + neuron = poisson_utils.PoissonNeuron( spike_rate=rate * args.scale, num_repeats=1, time_interval=args.time_interval, diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index a32209a3..05c47acb 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -31,12 +31,10 @@ """ -import hashlib import pathlib import joblib import numpy as np -import scipy.io import sklearn.model_selection import sklearn.neighbors import torch @@ -162,14 +160,16 @@ def decode(self, x_train, y_train, x_test, y_test): class SingleRatTrialSplitDataset(SingleRatDataset): """A single rat hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits. - Neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat. - The behavior label is structured as 3D array consists of position, right, and left. - The neural and behavior recordings are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. + Neural data is spike counts binned into 25ms time window and the behavior is position and the running + direction (left, right) of a rat. The behavior label is structured as 3D array consists of position, + right, and left. The neural and behavior recordings are parsed into trials (a round trip from one end + of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. Args: name: The name of a rat to use. Choose among 'achilles', 'buddy', 'cicero' and 'gatsby'. split_no: The `k` for k-fold split. Choose among 0, 1, 2. - split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split). + split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test' + (all trials except test split). """ @@ -283,13 +283,16 @@ class MultipleRatsTrialSplitDataset(cebra.data.DatasetCollection): """4 rats hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits. Neural and behavior recordings of 4 rats. - For each rat, neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat. + For each rat, neural data is spike counts binned into 25ms time window and the behavior is position + and the running direction (left, right) of a rat. The behavior label is structured as 3D array consists of position, right, and left. - Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation. + Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) + and the trials are split into a train, valid and test set with k=3 nested cross validation. Args: split_no: The `k` for k-fold split. Choose among 0, 1, and 2. - split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split). + split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test' + (all trials except test split). """ diff --git a/cebra/datasets/make_neuropixel.py b/cebra/datasets/make_neuropixel.py index 7c097f38..431191db 100644 --- a/cebra/datasets/make_neuropixel.py +++ b/cebra/datasets/make_neuropixel.py @@ -21,22 +21,24 @@ # """Generate pseudomouse Neuropixels data. -This script generates the pseudomouse Neuropixels data for each visual cortical area from original Allen ENuropixels Brain observatory 1.1 NWB data. -We followed the units filtering used in the AllenSDK package. +This script generates the pseudomouse Neuropixels data for each visual cortical area from original +Allen ENuropixels Brain observatory 1.1 NWB data. We followed the units filtering used in the AllenSDK package. References: - *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. - *https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html + * Siegle, Joshua H., et al. + "Survey of spiking in the mouse visual system reveals functional hierarchy." + Nature 592.7852 (2021): 86-92. + * https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html """ import argparse import glob +import pathlib import h5py import joblib as jl import numpy as np import numpy.typing as npt -import pandas as pd def _filter_units( @@ -194,11 +196,12 @@ def read_neuropixel( "intervals/natural_movie_one_presentations/start_time"][...] end_time = d[ "intervals/natural_movie_one_presentations/stop_time"][...] - timeseries = d[ - "intervals/natural_movie_one_presentations/timeseries"][...] - timeseries_index = d[ - "intervals/natural_movie_one_presentations/timeseries_index"][ - ...] + # NOTE(stes): Never used. Commenting, but leaving for future ref. + #timeseries = d[ + # "intervals/natural_movie_one_presentations/timeseries"][...] + #timeseries_index = d[ + # "intervals/natural_movie_one_presentations/timeseries_index"][ + # ...] session_no = d["identifier"][...].item() spike_time_index = d["units/spike_times_index"][...] spike_times = d["units/spike_times"][...] @@ -263,13 +266,13 @@ def read_neuropixel( "neural": sessions_dic, "frames": session_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl") jl.dump( { "neural": pseudo_mice, "frames": pseudo_mice_frames }, - Path(args.save_path) / + pathlib.Path(args.save_path) / f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl", ) diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 23fc5a6c..05071b12 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -22,20 +22,23 @@ """Ephys neural and behavior data used for the monkey reaching experiment. References: - * Chowdhury, Raeed H., Joshua I. Glaser, and Lee E. Miller. "Area 2 of primary somatosensory cortex encodes kinematics of the whole arm." Elife 9 (2020). - * Chowdhury, Raeed; Miller, Lee (2022) Area2 Bump: macaque somatosensory area 2 spiking activity during reaching with perturbations (Version 0.220113.0359) [Data set]. `DANDI archive `_ - * Pei, Felix, et al. "Neural Latents Benchmark'21: Evaluating latent variable models of neural population activity." arXiv preprint arXiv:2109.04463 (2021). - + * Chowdhury, Raeed H., Joshua I. Glaser, and Lee E. Miller. + "Area 2 of primary somatosensory cortex encodes kinematics of the whole arm." + Elife 9 (2020). + * Chowdhury, Raeed; Miller, Lee (2022) + Area2 Bump: macaque somatosensory area 2 spiking activity during reaching + with perturbations (Version 0.220113.0359) [Data set]. + `DANDI archive `_ + * Pei, Felix, et al. + "Neural Latents Benchmark'21: Evaluating latent variable models of neural + population activity." arXiv preprint arXiv:2109.04463 (2021). """ -import hashlib import pathlib -import pickle as pk from typing import Union import joblib as jl import numpy as np -import scipy.io import torch import cebra.data @@ -72,7 +75,7 @@ def _load_data( try: from nlb_tools.nwb_interface import NWBDataset - except ImportError as e: + except ImportError: raise ImportError( "Could not import the nlb_tools package required for data loading " "the raw reaching datasets in NWB format. " @@ -424,7 +427,7 @@ def _create_area2_dataset(): for session_type in ["active", "passive", "active-passive", "all"]: @register(f"area2-bump-pos-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV1(Area2BumpDataset): """Monkey reaching dataset with hand position labels. The dataset loads continuous x,y hand position as behavior labels. @@ -453,7 +456,7 @@ def continuous_index(self): return self.pos @register(f"area2-bump-target-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV2(Area2BumpDataset): """Monkey reaching dataset with target direction labels. The dataset loads discrete target direction (0-7) as behavior labels. @@ -480,7 +483,7 @@ def continuous_index(self): return None @register(f"area2-bump-posdir-{session_type}") - class Dataset(Area2BumpDataset): + class DatasetV3(Area2BumpDataset): """Monkey reaching dataset with hand position labels and discrete target labels. The dataset loads continuous x,y hand position and discrete target labels (0-7) @@ -523,7 +526,7 @@ def _create_area2_shuffled_dataset(): for session_type in ["active", "active-passive"]: @register(f"area2-bump-pos-{session_type}-shuffled-trial") - class Dataset(Area2BumpShuffledDataset): + class DatasetV4(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled trial type. The dataset loads the discrete binary trial type label active(0)/passive(1) @@ -551,7 +554,7 @@ def continuous_index(self): return self.pos @register(f"area2-bump-pos-{session_type}-shuffled-position") - class Dataset(Area2BumpShuffledDataset): + class DatasetV5(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled hand position. The dataset loads continuous x,y hand position in randomly shuffled order. @@ -580,7 +583,7 @@ def continuous_index(self): return self.pos_shuffled @register(f"area2-bump-target-{session_type}-shuffled") - class Dataset(Area2BumpShuffledDataset): + class DatasetV6(Area2BumpShuffledDataset): """Monkey reaching dataset with the shuffled hand position. The dataset loads discrete target direction (0-7 for active and 0-15 for active-passive) diff --git a/cebra/datasets/save_dataset.py b/cebra/datasets/save_dataset.py index f5e01a62..a9b3862a 100644 --- a/cebra/datasets/save_dataset.py +++ b/cebra/datasets/save_dataset.py @@ -50,7 +50,7 @@ def save_allen_decoding_dataset(savepath=get_datapath("allen_preload/")): print(f"{savepath}/{dataname}.jl") -def save_allen_dataset(savepath=get_datapth("allen_preload/")): +def save_allen_dataset(savepath=get_datapath("allen_preload/")): """Load and save complete allen dataset for Ca. Load and save all neural and behavioral data relevant for allen decoding dataset to reduce data loading time for the experiments using the shared data. It saves Ca data for the neuron numbers (10-1000), 5 different seeds for sampling the neurons. diff --git a/cebra/distributions/base.py b/cebra/distributions/base.py index 990d7e79..07ad9ae4 100644 --- a/cebra/distributions/base.py +++ b/cebra/distributions/base.py @@ -31,7 +31,6 @@ """ import abc -import functools import torch @@ -82,7 +81,7 @@ def to(self, device: str): self._generator = torch.Generator(device=device) try: self._generator.set_state(state.to(device)) - except (TypeError, RuntimeError) as e: + except (TypeError, RuntimeError): # TODO(https://discuss.pytorch.org/t/cuda-rng-state-does-not-change-when-re-seeding-why-is-that/47917/3) self._generator.manual_seed(self.seed) diff --git a/cebra/distributions/continuous.py b/cebra/distributions/continuous.py index c4235d48..ad95fdf6 100644 --- a/cebra/distributions/continuous.py +++ b/cebra/distributions/continuous.py @@ -23,7 +23,6 @@ from typing import Literal, Optional -import numpy as np import torch import cebra.data @@ -112,8 +111,8 @@ def __init__( abc_.HasGenerator.__init__(self, device=device, seed=seed) if continuous is None and num_samples is None: raise ValueError( - f"Supply either a continuous index (which will be used to infer the dataset size) " - f"or alternatively the number of datapoints using the num_samples argument." + "Supply either a continuous index (which will be used to infer the dataset size) " + "or alternatively the number of datapoints using the num_samples argument." ) if continuous is not None and num_samples is not None: if len(continuous) != num_samples: diff --git a/cebra/distributions/index.py b/cebra/distributions/index.py index 0ee0959a..724e86e4 100644 --- a/cebra/distributions/index.py +++ b/cebra/distributions/index.py @@ -30,7 +30,6 @@ discrete labels should be converted accordingly. """ -import numpy as np import torch import cebra.data @@ -188,9 +187,9 @@ def __init__(self, discrete, continuous): "of samples.") if len(discrete.shape) > 1: raise ValueError( - f"Discrete indexing information needs to be limited to a 1d " - f"array/tensor. Multi-dimensional discrete indices should be " - f"reformatted first.") + "Discrete indexing information needs to be limited to a 1d " + "array/tensor. Multi-dimensional discrete indices should be " + "reformatted first.") # TODO(stes): Once a helper function exists, the error message should # mention it. diff --git a/cebra/distributions/mixed.py b/cebra/distributions/mixed.py index 14fb8a61..7221fd99 100644 --- a/cebra/distributions/mixed.py +++ b/cebra/distributions/mixed.py @@ -27,7 +27,6 @@ """ from typing import Literal -import numpy as np import torch import cebra.io diff --git a/cebra/grid_search.py b/cebra/grid_search.py index 14337ac0..1805c896 100644 --- a/cebra/grid_search.py +++ b/cebra/grid_search.py @@ -138,7 +138,8 @@ def fit_models(self, to fit the CEBRA models on. The models are then trained using temporal contrastive learning (CEBRA-Time). An example of a valid ``datasets`` value could be: - ``datasets={"dataset1": neural_data, "dataset2": (neurald_data, continuous_data, discrete_data), "dataset3": (neural_data2, continuous_data2)}``. + ``datasets={"dataset1": neural_data, "dataset2": (neurald_data, continuous_data, discrete_data), + "dataset3": (neural_data2, continuous_data2)}``. params: Dict of parameter values provided by the user, either as a single value, for fixed hyperparameter values, or with a list of values for hyperparameters to optimize. If the value is a list of a single element, the hyperparameter is considered as fixed. diff --git a/cebra/helper.py b/cebra/helper.py index 8e9557e5..93bae2b1 100644 --- a/cebra/helper.py +++ b/cebra/helper.py @@ -165,7 +165,7 @@ def _requires_package_version(function): @wraps(function) def wrapper(*args, patched_version=None, **kwargs): - if patched_version != None: + if patched_version is not None: installed_version = pkg_resources.parse_version( patched_version) # Use the patched version if provided else: diff --git a/cebra/integrations/deeplabcut.py b/cebra/integrations/deeplabcut.py index c265b09a..4c5b292d 100644 --- a/cebra/integrations/deeplabcut.py +++ b/cebra/integrations/deeplabcut.py @@ -160,7 +160,7 @@ def load_data(self, pcutoff: float = 0.6) -> npt.NDArray: ) elif self.dlc_df.columns.nlevels == 4: raise NotImplementedError( - f"Multi-animals DLC files are not handled. Please provide a single-animal file." + "Multi-animals DLC files are not handled. Please provide a single-animal file." ) dlc_df_coords = ( diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py index b79deb66..30af7fd4 100644 --- a/cebra/integrations/matplotlib.py +++ b/cebra/integrations/matplotlib.py @@ -35,6 +35,7 @@ import torch from cebra import CEBRA +from cebra.helper import requires_package_version def _register_colormap(): @@ -289,13 +290,13 @@ def _define_plot_dim( * If ``idx_order`` is not provided, the plot will be 3D by default. * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if only 2 dimensions - are provided, the plot will be 2D. + are provided, the plot will be 2D. If the embedding dimension is equal to 2: * If ``idx_order`` is not provided, the plot will be 2D by default. * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if 2 dimensions - are provided, the plot will be 2D. + are provided, the plot will be 2D. This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). @@ -480,8 +481,9 @@ def plot(self, **kwargs) -> matplotlib.axes.Axes: elif isinstance(self.embedding_labels, Iterable): if len(self.embedding_labels) != self.embedding.shape[0]: raise ValueError( - f"Invalid embedding labels: the labels vector should have the same number of samples as the embedding, got {len(self.embedding_labels)}, expect {self.embedding.shape[0]}." - ) + f"Invalid embedding labels: the labels vector should have the same number " + f"of samples as the embedding, got {len(self.embedding_labels)}, " + f"expected {self.embedding.shape[0]}.") if self.embedding_labels.ndim > 1: raise NotImplementedError( f"Invalid embedding labels: plotting does not support multiple sets of labels, got {self.embedding_labels.ndim}." @@ -668,8 +670,9 @@ def _to_heatmap_format( assert len(pairs) == len(values), (self.pairs.shape, len(values)) score_dict = {tuple(pair): value for pair, value in zip(pairs, values)} - if self.labels is None: - n_grid = self.score + # NOTE(stes): Never used, might be possible to remove. + #if self.labels is None: + # n_grid = self.score heatmap_values = np.zeros((len(self.labels), len(self.labels))) @@ -1012,46 +1015,53 @@ def plot_embedding( If the embedding dimension is equal or higher to 3: - * If ``idx_order`` is not provided, the plot will be 3D by default. - * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if only 2 dimensions are provided, the plot will be 2D. + - If ``idx_order`` is not provided, the plot will be 3D by default. + - If ``idx_order`` is provided, and it has 3 dimensions, the plot will be 3D; + if only 2 dimensions are provided, the plot will be 2D. If the embedding dimension is equal to 2: - * If ``idx_order`` is not provided, the plot will be 2D by default. - * If ``idx_order`` is provided, if it has 3 dimensions, the plot will be 3D, if 2 dimensions are provided, the plot will be 2D. + - If ``idx_order`` is not provided, the plot will be 2D by default. + - If ``idx_order`` is provided, and it has 3 dimensions, the plot will be 3D; + if 2 dimensions are provided, the plot will be 2D. - This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of - dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). - - The function makes use of :py:func:`matplotlib.pyplot.scatter` and parameters from that function can be provided - as part of ``kwargs``. + This assumes that the dimensions provided to ``idx_order`` are within the range of the + number of dimensions of the embedding (i.e., between 0 and + :py:attr:`cebra.CEBRA.output_dimension` -1). + The function makes use of :py:func:`matplotlib.pyplot.scatter`, and parameters from + that function can be provided as part of ``kwargs``. Args: embedding: A matrix containing the feature representation computed with CEBRA. embedding_labels: The labels used to map the data to color. It can be: - * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. - * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. + - A vector that is the same sample size as the embedding, associating a value + to each sample, either discrete or continuous. + - A string, either `time`, which will color the embedding based on temporality, + or a string that can be interpreted as an RGB(A) color, which will display + the embedding uniformly with that color. + ax: Optional axis to create the plot on. - idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D - embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those - dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation - (e.g., (2, 4, 5)). + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension + in the 3D/2D embedding. The simplest form is (0, 1, 2) or (0, 1), but one might + want to plot either those dimensions differently (e.g., (1, 0, 2)) or other + dimensions from the feature representation (e.g., (2, 4, 5)). + markersize: The marker size. alpha: The marker blending, between 0 (transparent) and 1 (opaque). - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). title: The title on top of the embedding. figsize: Figure width and height in inches. dpi: Figure resolution. - kwargs: Optional arguments to customize the plots. See :py:func:`matplotlib.pyplot.scatter` documentation for more - details on which arguments to use. + kwargs: Optional arguments to customize the plots. See :py:func:`matplotlib.pyplot.scatter` + documentation for more details on which arguments to use. Returns: The axis :py:meth:`matplotlib.axes.Axes.axis` of the plot. Example: - >>> import cebra >>> import numpy as np >>> X = np.random.uniform(0, 1, (100, 50)) @@ -1061,8 +1071,8 @@ def plot_embedding( CEBRA(max_iterations=10) >>> embedding = cebra_model.transform(X) >>> ax = cebra.plot_embedding(embedding, embedding_labels='time') - """ + return _EmbeddingPlot( embedding=embedding, embedding_labels=embedding_labels, @@ -1134,7 +1144,12 @@ def plot_consistency( >>> labels2 = np.random.uniform(0, 1, (1000, )) >>> dataset_ids = ["achilles", "buddy"] >>> # between-datasets consistency, by aligning on the labels - >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2], labels=[labels1, labels2], dataset_ids=dataset_ids, between="datasets") + >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score( + ... embeddings=[embedding1, embedding2], + ... labels=[labels1, labels2], + ... dataset_ids=dataset_ids, + ... between="datasets" + ... ) >>> ax = cebra.plot_consistency(scores, pairs, datasets, vmin=0, vmax=100) """ @@ -1153,9 +1168,6 @@ def plot_consistency( ).plot(**kwargs) -from cebra.helper import requires_package_version - - @requires_package_version(matplotlib, "3.6") def compare_models( models: List[CEBRA], diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index e60bcb8f..bbaa1de6 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -79,7 +79,8 @@ def _define_colorscale(self, cmap: str): """Specify the cmap for plotting the latent space. Args: - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). Returns: @@ -171,39 +172,41 @@ def plot_embedding_interactive( The function makes use of :py:class:`plotly.graph_objects.Scatter` and parameters from that function can be provided as part of ``kwargs``. - Args: embedding: A matrix containing the feature representation computed with CEBRA. embedding_labels: The labels used to map the data to color. It can be: - * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. - * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. + - A vector that is the same sample size as the embedding, associating a value + to each of the sample, either discrete or continuous. + - A string, either `time`, then the labels will color the embedding based on + temporality, or a string that can be interpreted as a RGB(A) color, then + the embedding will be uniformly displayed with that unique color. + axis: Optional axis to create the plot on. - idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D - embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those - dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation - (e.g., (2, 4, 5)). + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension + in the 3D/2D embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want + to plot either those dimensions differently (e.g., (1, 0, 2)) or other dimensions + from the feature representation (e.g., (2, 4, 5)). + markersize: The marker size. alpha: The marker blending, between 0 (transparent) and 1 (opaque). - cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. + It will be ignored if `embedding_labels` is set to a valid RGB(A). title: The title on top of the embedding. figsize: Figure width and height in inches. dpi: Figure resolution. - kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments: - -- showlegend: Whether to show the legend or not. - -- discrete: Whether the labels are discrete or not. - -- col: The column of the subplot to plot the embedding on. - -- row: The row of the subplot to plot the embedding on. - -- template: The template to use for the plot. - - Note: showlegend can be True only if discrete is True. - - See :py:class:`plotly.graph_objects.Scatter` documentation for more - details on which arguments to use. - - Returns: - The plotly figure. - + kwargs: Optional arguments to customize the plots. This dictionary includes the following + optional arguments: + + - ``showlegend``: Whether to show the legend or not. + - ``discrete``: Whether the labels are discrete or not. + - ``col``: The column of the subplot to plot the embedding on. + - ``row``: The row of the subplot to plot the embedding on. + - ``template``: The template to use for the plot. + + Note: ``showlegend`` can be ``True`` only if ``discrete`` is ``True``. + See :py:class:`plotly.graph_objects.Scatter` documentation for more details on which + arguments to use. Example: diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index bf038237..046d3344 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -21,9 +21,7 @@ # """Define the CEBRA model.""" -import copy import itertools -import warnings from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union) @@ -33,7 +31,6 @@ import sklearn.utils.validation as sklearn_utils_validation import torch from sklearn.base import BaseEstimator -from sklearn.base import ClassifierMixin from sklearn.base import TransformerMixin from torch import nn @@ -274,8 +271,8 @@ def _require_arg(key): "Until then, please train using the PyTorch API.")) else: raise RuntimeError( - f"Index combination not covered. Please report this issue and add the following " - f"information to your bug report: \n" + error_message) + "Index combination not covered. Please report this issue and add the following " + "information to your bug report: \n" + error_message) def _check_type_checkpoint(checkpoint): @@ -317,7 +314,8 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": for key, value in state.items(): setattr(cebra_, key, value) - state_and_args = {**args, **state} + #TODO(stes): unused right now + #state_and_args = {**args, **state} if not sklearn_utils.check_fitted(cebra_): raise ValueError( @@ -776,20 +774,20 @@ def _configure_for_all( cebra.models.ConvolutionalModelMixin): if len(model[n].get_offset()) > 1: raise ValueError( - f"It is not yet supported to run non-convolutional models with " - f"receptive fields/offsets larger than 1 via the sklearn API. " - f"Please use a different model, or revert to the pytorch " - f"API for training.") + "It is not yet supported to run non-convolutional models with " + "receptive fields/offsets larger than 1 via the sklearn API. " + "Please use a different model, or revert to the pytorch " + "API for training.") d.configure_for(model[n]) else: if not isinstance(model, cebra.models.ConvolutionalModelMixin): if len(model.get_offset()) > 1: raise ValueError( - f"It is not yet supported to run non-convolutional models with " - f"receptive fields/offsets larger than 1 via the sklearn API. " - f"Please use a different model, or revert to the pytorch " - f"API for training.") + "It is not yet supported to run non-convolutional models with " + "receptive fields/offsets larger than 1 via the sklearn API. " + "Please use a different model, or revert to the pytorch " + "API for training.") dataset.configure_for(model) @@ -1338,13 +1336,13 @@ def save(self, - 'args': A dictionary of parameters used to initialize the CEBRA model. - 'state': The state of the CEBRA model, which includes various internal attributes. - 'state_dict': The state dictionary of the underlying solver used by CEBRA. - - 'metadata': Additional metadata about the saved model, including the backend used and the version of CEBRA PyTorch, NumPy and scikit-learn. + - 'metadata': Additional metadata about the saved model, including the backend used and + the version of CEBRA PyTorch, NumPy and scikit-learn. "torch" backend: The model is directly saved using `torch.save` with no additional information. The saved file contains the entire CEBRA model state. - Example: >>> import cebra @@ -1443,12 +1441,12 @@ def load(cls, if isinstance(checkpoint, dict) and backend == "torch": raise RuntimeError( - f"Cannot use 'torch' backend with a dictionary-based checkpoint. " - f"Please try a different backend.") + "Cannot use 'torch' backend with a dictionary-based checkpoint. " + "Please try a different backend.") if not isinstance(checkpoint, dict) and backend == "sklearn": raise RuntimeError( - f"Cannot use 'sklearn' backend a non dictionary-based checkpoint. " - f"Please try a different backend.") + "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " + "Please try a different backend.") if backend == "sklearn": cebra_ = _load_cebra_with_sklearn_backend(checkpoint) diff --git a/cebra/integrations/sklearn/helpers.py b/cebra/integrations/sklearn/helpers.py index 06095c1e..2d2fc627 100644 --- a/cebra/integrations/sklearn/helpers.py +++ b/cebra/integrations/sklearn/helpers.py @@ -40,9 +40,9 @@ def _get_min_max( min = float("inf") max = float("-inf") for label in labels: - if any(isinstance(l, str) for l in label): + if any(isinstance(label_element, str) for label_element in label): raise ValueError( - f"Invalid labels dtype, expect floats or integers, got string") + "Invalid labels dtype, expect floats or integers, got string") min = np.min(label) if min > np.min(label) else min max = np.max(label) if max < np.max(label) else max return min, max diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 9712d021..ccecaa11 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -187,7 +187,7 @@ def _consistency_datasets( if labels is None: raise ValueError( "Missing labels, computing consistency between datasets requires labels, expect " - f"a set of labels for each embedding.") + "a set of labels for each embedding.") if len(embeddings) != len(labels): raise ValueError( "Invalid set of labels, computing consistency between datasets requires labels, " @@ -273,8 +273,8 @@ def _consistency_runs( if not all(embeddings[0].shape[0] == embeddings[i].shape[0] for i in range(1, len(embeddings))): raise ValueError( - f"Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way." - f"If your embeddings are coming from different models, you can use between-datasets" + "Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way." + "If your embeddings are coming from different models, you can use between-datasets" ) run_ids = np.arange(len(embeddings)) @@ -353,11 +353,11 @@ def consistency_score( if between == "runs": if labels is not None: raise ValueError( - f"No labels should be provided for between-runs consistency.") + "No labels should be provided for between-runs consistency.") if dataset_ids is not None: raise ValueError( - f"No dataset ID should be provided for between-runs consistency." - f"All embeddings should be computed on the same dataset.") + "No dataset ID should be provided for between-runs consistency." + "All embeddings should be computed on the same dataset.") scores, pairs, ids = _consistency_runs(embeddings=embeddings,) elif between == "datasets": scores, pairs, ids = _consistency_datasets( diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index 8dbdc2b4..47c2a87f 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -33,9 +33,10 @@ """ import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch +import torch.nn.functional as F from torch import nn @@ -212,7 +213,6 @@ def __init__(self, self.max_inverse_temperature = math.inf else: self.max_inverse_temperature = 1.0 / min_temperature - start_tempearture = float(temperature) log_inverse_temperature = torch.tensor( math.log(1.0 / float(temperature))) self.log_inverse_temperature = nn.Parameter(log_inverse_temperature) diff --git a/cebra/models/model.py b/cebra/models/model.py index f4a5d862..7631ba86 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -22,10 +22,8 @@ """Neural network models and criterions for training CEBRA models.""" import abc -import literate_dataclasses as dataclasses import torch import torch.nn.functional as F -import tqdm from torch import nn import cebra.data diff --git a/cebra/models/multiobjective.py b/cebra/models/multiobjective.py index da7f992e..d9393fdc 100644 --- a/cebra/models/multiobjective.py +++ b/cebra/models/multiobjective.py @@ -94,7 +94,7 @@ def is_valid(self, mode): Returns: ``True`` for a valid representation, ``False`` otherwise. """ - return mode in _ALL + return mode in _ALL # noqa: F821 def __init__( self, diff --git a/cebra/models/projector.py b/cebra/models/projector.py index 0c924296..dd7388bc 100644 --- a/cebra/models/projector.py +++ b/cebra/models/projector.py @@ -134,7 +134,7 @@ def features(self, inp, index): return self._features[index](inp) def forward(self, inp): - raise NotImplemented() + raise NotImplementedError() def get_offset(self) -> cebra.data.Offset: return cebra.data.Offset(5, 5) diff --git a/cebra/registry.py b/cebra/registry.py index be9afbd0..994fbd5c 100644 --- a/cebra/registry.py +++ b/cebra/registry.py @@ -115,7 +115,9 @@ def get_options( ): instance = cls.get_instance(module) if expand_parametrized: - filter_ = lambda k, v: True + + def filter_(k, v): + return True else: class _Filter(set): diff --git a/cebra/solver/base.py b/cebra/solver/base.py index c350ba35..e95151e5 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,11 +32,10 @@ import abc import os -from typing import Callable, Dict, List, Literal, Optional, Union +from typing import Callable, Dict, List, Literal, Optional import literate_dataclasses as dataclasses import torch -import tqdm import cebra import cebra.data @@ -204,7 +203,7 @@ def fit( validation_loss = self.validation(valid_loader) if self.best_loss is None or validation_loss < self.best_loss: self.best_loss = validation_loss - self.save(logdir, f"checkpoint_best.pth") + self.save(logdir, "checkpoint_best.pth") if save_model: if decode: self.decode_history.append( diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 8f456eb6..eabce729 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -21,11 +21,8 @@ # """Solver implementations for multi-session datasetes.""" -import abc -from collections.abc import Iterable from typing import List, Optional -import literate_dataclasses as dataclasses import torch import cebra diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index 6b3b1030..d172fadc 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -21,10 +21,7 @@ # """Single session solvers embed a single pair of time series.""" -import abc import copy -from collections.abc import Iterable -from typing import List import literate_dataclasses as dataclasses import torch diff --git a/cebra/solver/supervised.py b/cebra/solver/supervised.py index f69308e6..54a2da3a 100644 --- a/cebra/solver/supervised.py +++ b/cebra/solver/supervised.py @@ -25,17 +25,9 @@ It is inclear whether these will be kept. Consider the implementation as experimental/outdated, and the API for this particular package unstable. """ -import abc -from collections.abc import Iterable -from typing import List -import literate_dataclasses as dataclasses import torch -import tqdm -import cebra -import cebra.data -import cebra.models import cebra.solver.base as abc_ @@ -69,7 +61,7 @@ def fit(self, step_idx = 0 while True: for _, batch in enumerate(loader): - stats = self.step(batch) + _ = self.step(batch) self._log_checkpoint(num_steps, loader, valid_loader) step_idx += 1 if step_idx >= num_steps: From 51f048d288b49d2591a62e42ae3a079999c44afc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Sun, 27 Oct 2024 10:36:13 -0300 Subject: [PATCH 3/4] Set default offset to an Offset object (#180) * Set default offset to an Offset object (fix #174) * Adapt default offset * Fix import issue caused by ruff Co-authored-by: Steffen Schneider --- cebra/data/datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index ecfc31ee..10c12223 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -30,6 +30,7 @@ import cebra.data as cebra_data import cebra.helper as cebra_helper +from cebra.data.datatypes import Offset class TensorDataset(cebra_data.SingleSessionDataset): @@ -65,7 +66,7 @@ def __init__(self, neural: Union[torch.Tensor, npt.NDArray], continuous: Union[torch.Tensor, npt.NDArray] = None, discrete: Union[torch.Tensor, npt.NDArray] = None, - offset: int = 1, + offset: Offset = Offset(0, 1), device: str = "cpu"): super().__init__(device=device) self.neural = self._to_tensor(neural, check_dtype="float").float() From e652b9a9423fe368474cfe4f5aa8fae12acde1bc Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 27 Oct 2024 15:21:46 +0100 Subject: [PATCH 4/4] Add additional tests for TensorDataset (#187) * Add additional tests for TensorDataset * Add explicit casting to avoid windows error --- cebra/data/datasets.py | 13 ++++- tests/test_datasets.py | 113 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 119 insertions(+), 7 deletions(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 10c12223..dbb2f1f5 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -71,7 +71,7 @@ def __init__(self, super().__init__(device=device) self.neural = self._to_tensor(neural, check_dtype="float").float() self.continuous = self._to_tensor(continuous, check_dtype="float") - self.discrete = self._to_tensor(discrete, check_dtype="integer") + self.discrete = self._to_tensor(discrete, check_dtype="int") if self.continuous is None and self.discrete is None: raise ValueError( "You have to pass at least one of the arguments 'continuous' or 'discrete'." @@ -87,7 +87,7 @@ def _to_tensor( Args: array: Array to check. - check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array` + check_dtype: If not `None`, list of dtypes to which the values in `array` must belong to. Defaults to None. Returns: @@ -98,11 +98,20 @@ def _to_tensor( if isinstance(array, np.ndarray): array = torch.from_numpy(array) if check_dtype is not None: + if check_dtype not in ["int", "float"]: + raise ValueError( + f"check_dtype must be 'int' or 'float', got {check_dtype}") if (check_dtype == "int" and not cebra_helper._is_integer(array) ) or (check_dtype == "float" and not cebra_helper._is_floating(array)): raise TypeError( f"Array has type {array.dtype} instead of {check_dtype}.") + if cebra_helper._is_floating(array): + array = array.float() + if cebra_helper._is_integer(array): + # NOTE(stes): Required for standardizing number format on + # windows machines. + array = array.long() return array @property diff --git a/tests/test_datasets.py b/tests/test_datasets.py index adbfab64..4bea0cf0 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -68,9 +68,9 @@ def test_demo(): @pytest.mark.requires_dataset def test_hippocampus(): - from cebra.datasets import hippocampus - pytest.skip("Outdated") + + from cebra.datasets import hippocampus # noqa: F401 dataset = cebra.datasets.init("rat-hippocampus-single") loader = cebra.data.ContinuousDataLoader( dataset=dataset, @@ -99,7 +99,7 @@ def test_hippocampus(): @pytest.mark.requires_dataset def test_monkey(): - from cebra.datasets import monkey_reaching + from cebra.datasets import monkey_reaching # noqa: F401 dataset = cebra.datasets.init( "area2-bump-pos-active-passive", @@ -111,7 +111,7 @@ def test_monkey(): @pytest.mark.requires_dataset def test_allen(): - from cebra.datasets import allen + from cebra.datasets import allen # noqa: F401 pytest.skip("Test takes too long") @@ -148,7 +148,7 @@ def test_allen(): multisubject_options.extend( cebra.datasets.get_options( "rat-hippocampus-multisubjects-3fold-trial-split*")) -except: +except: # noqa: E722 options = [] @@ -388,3 +388,106 @@ def test_download_file_wrong_content_disposition(filename, url, expected_checksum=expected_checksum, location=temp_dir, file_name=filename) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_initialization(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + assert dataset.neural.shape == neural.shape + if continuous is not None: + assert dataset.continuous.shape == continuous.shape + if discrete is not None: + assert dataset.discrete.shape == discrete.shape + + +def test_tensor_dataset_invalid_initialization(): + neural = np.random.randn(100, 30) + with pytest.raises(ValueError): + cebra.data.datasets.TensorDataset(neural) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_length(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + assert len(dataset) == len(neural) + + +@pytest.mark.parametrize("neural, continuous, discrete", [ + (np.random.randn(100, 30), np.random.randn( + 100, 2), np.random.randint(0, 5, (100,))), + (np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))), + (np.random.randn(200, 40), np.random.randn(200, 5), None), +]) +def test_tensor_dataset_getitem(neural, continuous, discrete): + dataset = cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + index = torch.randint(0, len(dataset), (10,)) + batch = dataset[index] + assert batch.shape[0] == len(index) + assert batch.shape[1] == neural.shape[1] + + +def test_tensor_dataset_invalid_discrete_type(): + neural = np.random.randn(100, 30) + continuous = np.random.randn(100, 2) + discrete = np.random.randn(100, 2) # Invalid type: float instead of int + with pytest.raises(TypeError): + cebra.data.datasets.TensorDataset(neural, + continuous=continuous, + discrete=discrete) + + +@pytest.mark.parametrize("array, check_dtype, expected_dtype", [ + (np.random.randn(100, 30), "float", torch.float32), + (np.random.randint(0, 5, (100, 30)), "int", torch.int64), + (torch.randn(100, 30), "float", torch.float32), + (torch.randint(0, 5, (100, 30)), "int", torch.int64), + (None, None, None), +]) +def test_to_tensor(array, check_dtype, expected_dtype): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + result = dataset._to_tensor(array, check_dtype=check_dtype) + if array is None: + assert result is None + else: + assert isinstance(result, torch.Tensor) + assert result.dtype == expected_dtype + + +def test_to_tensor_invalid_dtype(): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + array = np.random.randn(100, 30) + with pytest.raises(TypeError): + dataset._to_tensor(array, check_dtype="int") + array = np.random.randint(0, 5, (100, 30)) + with pytest.raises(TypeError): + dataset._to_tensor(array, check_dtype="float") + + +def test_to_tensor_invalid_check_dtype(): + dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2), + continuous=np.random.randn( + 10, 2)) + array = np.random.randn(100, 30) + with pytest.raises(ValueError, + match="check_dtype must be 'int' or 'float', got"): + dataset._to_tensor(array, check_dtype="invalid_dtype")