diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py
index 7802b787..993078fa 100644
--- a/cebra/data/single_session.py
+++ b/cebra/data/single_session.py
@@ -265,30 +265,69 @@ class MixedDataLoader(cebra_data.Loader):

     Sampling can be configured in different modes:

-    1. Positive pairs always share their discrete variable.
+    1. Positive pairs always share their discrete variable (positive_sampling = "discrete_variable").
     2. Positive pairs are drawn only based on their conditional,
-       not discrete variable.
+       not discrete variable (positive_sampling = "conditional").
+
+    When using the discrete variable, the prior distribution can either be uniform
+    (discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical").
+
+    Based on the selection of those parameters, the :py:class:`cebra.distributions.mixed.MixedTimeDeltaDistribution`,
+    :py:class:`cebra.distributions.discrete.DiscreteEmpirical`, or :py:class:`cebra.distributions.discrete.DiscreteUniform`
+    distributions are used for sampling.
+
+    Args:
+        conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
+        time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
+        positive_sampling (str): either "discrete_variable" (default) or "conditional"
+        discrete_sampling_prior (str): either "uniform" (default) or "empirical"
     """

     conditional: str = dataclasses.field(default="time_delta")
     time_offset: int = dataclasses.field(default=10)
+    positive_sampling: str = dataclasses.field(default="discrete_variable")
+    discrete_sampling_prior: str = dataclasses.field(default="uniform")

     @property
     def dindex(self):
-        # TODO(stes) rename to discrete_index
+        warnings.warn("dindex is deprecated. Use discrete_index instead.",
+                      DeprecationWarning, stacklevel=2)
+        return self.dataset.discrete_index
+
+    @property
+    def discrete_index(self):
         return self.dataset.discrete_index

     @property
     def cindex(self):
-        # TODO(stes) rename to continuous_index
+        warnings.warn("cindex is deprecated. Use continuous_index instead.",
+                      DeprecationWarning, stacklevel=2)
+        return self.dataset.continuous_index
+
+    @property
+    def continuous_index(self):
         return self.dataset.continuous_index

     def __post_init__(self):
         super().__post_init__()
-        self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
-            discrete=self.dindex,
-            continuous=self.cindex,
-            time_delta=self.time_offset)
+        if self.positive_sampling == "conditional":
+            self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
+                discrete=self.discrete_index,
+                continuous=self.continuous_index,
+                time_delta=self.time_offset)
+        elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
+            self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
+        elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
+            self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
+        elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]:
+            raise ValueError(
+                f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
+                f"only accept 'uniform' or 'empirical' as potential values.")
+        else:
+            raise ValueError(
+                f"Invalid positive sampling mode: "
+                f"{self.positive_sampling} valid options are "
+                f"'conditional' or 'discrete_variable'.")

     def get_indices(self, num_samples: int) -> BatchIndex:
         """Samples indices for reference, positive and negative examples.
@@ -313,12 +352,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
               class.
             - Sample the negatives with matching discrete variable
         """
-        reference_idx = self.distribution.sample_prior(num_samples)
-        return BatchIndex(
-            reference=reference_idx,
-            negative=self.distribution.sample_prior(num_samples),
-            positive=self.distribution.sample_conditional(reference_idx),
-        )
+        if self.positive_sampling == "conditional":
+            reference_idx = self.distribution.sample_prior(num_samples)
+            return BatchIndex(
+                reference=reference_idx,
+                negative=self.distribution.sample_prior(num_samples),
+                positive=self.distribution.sample_conditional(reference_idx),
+            )
+        else:
+            # Same scheme as DiscreteDataLoader.get_indices: draw 2 * num_samples prior indices, then split them into reference and negative halves.
+            reference_idx = self.distribution.sample_prior(num_samples * 2)
+            negative_idx = reference_idx[num_samples:]
+            reference_idx = reference_idx[:num_samples]
+            reference = self.discrete_index[reference_idx]
+            positive_idx = self.distribution.sample_conditional(reference)
+            return BatchIndex(reference=reference_idx,
+                              positive=positive_idx,
+                              negative=negative_idx)


 @dataclasses.dataclass
diff --git a/tests/test_loader.py b/tests/test_loader.py
index 562f64a7..e2d819cf 100644
--- a/tests/test_loader.py
+++ b/tests/test_loader.py
@@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark):
     benchmark(load_speed)


+@parametrize_device
+@pytest.mark.parametrize(
+    "conditional, positive_sampling, discrete_sampling_prior",
+    [
+        ("time", "discrete_variable", "empirical"),
+        ("time", "conditional", "empirical"),
+        ("time", "discrete_variable", "uniform"),
+        ("time", "conditional", "uniform"),
+        ("time_delta", "discrete_variable", "empirical"),
+        ("time_delta", "conditional", "empirical"),
+        ("time_delta", "discrete_variable", "uniform"),
+        ("time_delta", "conditional", "uniform"),
+    ],
+)
+def test_mixed(
+    conditional, positive_sampling, discrete_sampling_prior, device, benchmark
+):
+    dataset = RandomDataset(N=100, d=5, device=device)
+    loader = cebra.data.MixedDataLoader(
+        dataset=dataset,
+        num_steps=10,
+        batch_size=8,
+        conditional=conditional,
+        positive_sampling=positive_sampling,
+        discrete_sampling_prior=discrete_sampling_prior,
+    )
+    _assert_dataset_on_correct_device(loader, device)
+    load_speed = LoadSpeed(loader)
+    benchmark(load_speed)
+
+
 def _check_attributes(obj, is_list=False):
     if is_list:
         for obj_ in obj: