diff --git a/sbi/neural_nets/estimators/base.py b/sbi/neural_nets/estimators/base.py index cc1438447..cf9390fec 100644 --- a/sbi/neural_nets/estimators/base.py +++ b/sbi/neural_nets/estimators/base.py @@ -119,15 +119,15 @@ def _check_input_shape(self, input: Tensor): class ConditionalDensityEstimator(ConditionalEstimator): r"""Base class for density estimators. - The density estimator class is a wrapper around neural networks that - allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ - pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`. + The density estimator class is a wrapper around neural networks that allows to + evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ pairs. Here + $\theta$ would be the `input` and $x$ would be the `condition`. Note: We assume that the input to the density estimator is a tensor of shape - (batch_size, input_size), where input_size is the dimensionality of the input. - The condition is a tensor of shape (batch_size, *condition_shape), where - condition_shape is the shape of the condition tensor. + (sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality + of the input. The condition is a tensor of shape (batch_size, *condition_shape), + where condition_shape is the shape of the condition tensor. """ @@ -226,15 +226,15 @@ def sample_and_log_prob( class ConditionalVectorFieldEstimator(ConditionalEstimator): r"""Base class for vector field (e.g., score and ODE flow) estimators. - The density estimator class is a wrapper around neural networks that - allows to evaluate the `vector_field`, and provide the `loss` of $\theta,x$ - pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`. + The vector field estimator class is a wrapper around neural networks that allows to + evaluate the `vector_field`, and provide the `loss` of $\theta,x$ pairs. Here + $\theta$ would be the `input` and $x$ would be the `condition`. Note: We assume that the input to the density estimator is a tensor of shape - (batch_size, input_size), where input_size is the dimensionality of the input. - The condition is a tensor of shape (batch_size, *condition_shape), where - condition_shape is the shape of the condition tensor. + (sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality + of the input. The condition is a tensor of shape (batch_dim, *condition_shape), + where condition_shape is the shape of the condition tensor. """ diff --git a/sbi/neural_nets/estimators/nflows_flow.py b/sbi/neural_nets/estimators/nflows_flow.py index 8edd9763b..04ba24196 100644 --- a/sbi/neural_nets/estimators/nflows_flow.py +++ b/sbi/neural_nets/estimators/nflows_flow.py @@ -81,7 +81,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: Args: input: Inputs to evaluate the log probability on. Of shape `(sample_dim, batch_dim, *event_shape)`. - condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`. + condition: Conditions of shape `(batch_dim, *event_shape)`. Raises: AssertionError: If `input_batch_dim != condition_batch_dim`. @@ -126,7 +126,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: Args: sample_shape: Shape of the samples to return. - condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`. + condition: Conditions of shape `(batch_dim, *event_shape)`. Returns: Samples of shape `(*sample_shape, condition_batch_dim)`. @@ -147,7 +147,7 @@ def sample_and_log_prob( Args: sample_shape: Shape of the samples to return. - condition: Conditions of shape (sample_dim, batch_dim, *event_shape). + condition: Conditions of shape (batch_dim, *event_shape). Returns: Samples of shape `(*sample_shape, condition_batch_dim, *input_event_shape)` diff --git a/sbi/neural_nets/estimators/zuko_flow.py b/sbi/neural_nets/estimators/zuko_flow.py index edc535d69..74b1bfc17 100644 --- a/sbi/neural_nets/estimators/zuko_flow.py +++ b/sbi/neural_nets/estimators/zuko_flow.py @@ -101,9 +101,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: Args: input: Inputs to evaluate the log probability on. Of shape `(sample_dim, batch_dim, *event_shape)`. - # TODO: the docstring is not correct here. in the code it seems we - do not have a sample_dim for the condition. - condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`. + condition: Conditions of shape `(batch_dim, *event_shape)`. Raises: AssertionError: If `input_batch_dim != condition_batch_dim`.