better convergence checks
gmoss13 committed Jan 30, 2025
1 parent b2e756c commit e59d6d0
Showing 1 changed file with 49 additions and 54 deletions.
103 changes: 49 additions & 54 deletions sbi/inference/trainers/npse/npse.py
@@ -179,11 +179,12 @@ def train(
training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 200,
stop_after_epochs: int = 50,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
ema_loss_decay: float = 0.1,
validation_times: Union[Tensor, int] = 20,
resume_training: bool = False,
force_first_round_loss: bool = False,
discard_prior_samples: bool = False,
@@ -208,6 +209,10 @@ def train(
calibration_kernel: A function to calibrate the loss with respect
to the simulations `x` (optional). See Lueckmann, Gonçalves et al.,
NeurIPS 2017. If `None`, no calibration is used.
ema_loss_decay: Decay rate for the exponential moving average of the
training and validation losses.
validation_times: Diffusion times at which to evaluate the validation loss;
evaluating at a fixed set of times reduces its variance.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
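
For context, here is a minimal usage sketch of the updated `train()` signature. The import paths, prior, and toy simulator are assumptions for illustration; the argument names and defaults are taken from the diff above.

```python
# Hedged usage sketch: import paths, prior, and toy "simulator" are assumed;
# the train() arguments and defaults come from the diff above.
import torch
from sbi.inference import NPSE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)  # stand-in simulator

inference = NPSE(prior=prior)
inference.append_simulations(theta, x)
score_estimator = inference.train(
    stop_after_epochs=50,   # new, lower default
    validation_times=20,    # int -> evenly spaced diffusion times (see below)
    ema_loss_decay=0.1,
)
```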
@@ -294,6 +299,14 @@ def default_calibration_kernel(x):
# Move entire net to device for training.
self._neural_net.to(self._device)

if isinstance(validation_times, int):
validation_times = torch.linspace(
self._neural_net.t_min, self._neural_net.t_max, validation_times
)
assert isinstance(
validation_times, Tensor
) # let pyright know validation_times is a Tensor.

if not resume_training:
self.optimizer = Adam(list(self._neural_net.parameters()), lr=learning_rate)

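The block above converts an integer `validation_times` into evenly spaced diffusion times between the score estimator's `t_min` and `t_max`. A standalone sketch of that conversion, with placeholder bounds:

```python
# Standalone sketch of the int -> Tensor conversion; t_min/t_max are
# placeholder bounds, the real values come from the score estimator.
from typing import Union

import torch
from torch import Tensor


def as_validation_times(
    validation_times: Union[Tensor, int], t_min: float = 1e-3, t_max: float = 1.0
) -> Tensor:
    if isinstance(validation_times, int):
        validation_times = torch.linspace(t_min, t_max, validation_times)
    return validation_times


print(as_validation_times(20).shape)  # torch.Size([20])
```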
@@ -316,11 +329,11 @@ def default_calibration_kernel(x):
)

train_losses = self._loss(
theta_batch,
x_batch,
masks_batch,
proposal,
calibration_kernel,
theta=theta_batch,
x=x_batch,
masks=masks_batch,
proposal=proposal,
calibration_kernel=calibration_kernel,
force_first_round_loss=force_first_round_loss,
)

@@ -345,12 +358,6 @@ def default_calibration_kernel(x):
# moving average of the training loss.
if len(self._summary["training_loss"]) == 0:
self._summary["training_loss"].append(train_loss_average)
else:
previous_loss = self._summary["training_loss"][-1]
self._summary["training_loss"].append(
(1.0 - ema_loss_decay) * previous_loss
+ ema_loss_decay * train_loss_average
)

# Calculate validation performance.
self._neural_net.eval()
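
The deleted `else` branch was a standard exponential moving average (EMA) update, `ema_t = (1 - decay) * ema_{t-1} + decay * loss_t`. A standalone sketch with made-up loss values:

```python
# Standalone sketch of the EMA update the removed branch performed
# (loss values are made up for illustration).
def ema_update(previous: float, current: float, decay: float = 0.1) -> float:
    """Blend the new loss into the running average; smaller decay = smoother."""
    return (1.0 - decay) * previous + decay * current


losses = [1.00, 0.80, 1.20, 0.70]
ema = losses[0]
for loss in losses[1:]:
    ema = ema_update(ema, loss)  # 0.98, then 1.002, then 0.9718
```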
@@ -363,20 +370,42 @@ def default_calibration_kernel(x):
batch[1].to(self._device),
batch[2].to(self._device),
)

# For validation loss, we evaluate at a fixed set of times to reduce
# the variance in the validation loss, for improved convergence
# checks. We evaluate the entire validation batch at all times, so
# we repeat the batches here to match.
val_batch_size = theta_batch.shape[0]
times_batch = validation_times.shape[0]
theta_batch = theta_batch.repeat(
times_batch, *([1] * (theta_batch.ndim - 1))
)
x_batch = x_batch.repeat(times_batch, *([1] * (x_batch.ndim - 1)))
masks_batch = masks_batch.repeat(
times_batch, *([1] * (masks_batch.ndim - 1))
)

validation_times_rep = validation_times.repeat_interleave(
val_batch_size, dim=0
)

# Take negative loss here to get validation log_prob.
val_losses = self._loss(
theta_batch,
x_batch,
masks_batch,
proposal,
calibration_kernel,
theta=theta_batch,
x=x_batch,
masks=masks_batch,
proposal=proposal,
calibration_kernel=calibration_kernel,
times=validation_times_rep,
force_first_round_loss=force_first_round_loss,
)

# print("val_losses: ", val_losses.shape)
val_loss_sum += val_losses.sum().item()

# Take mean over all validation samples.
val_loss = val_loss_sum / (
len(val_loader) * val_loader.batch_size # type: ignore
len(val_loader) * val_loader.batch_size * times_batch # type: ignore
)

# NOTE: Due to the inherently noisy nature we instead log an exponential
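
To make the tiling above concrete: every validation sample is evaluated at every validation time, which is also why the loss denominator gains a factor of `times_batch`. A shape-only sketch with illustrative sizes:

```python
# Shape-only sketch of the batch/time tiling used for the validation loss
# (sizes are illustrative).
import torch

val_batch_size, times_batch, dim = 4, 3, 2
theta_batch = torch.randn(val_batch_size, dim)
validation_times = torch.linspace(0.1, 1.0, times_batch)

theta_rep = theta_batch.repeat(times_batch, 1)                   # (12, 2)
times_rep = validation_times.repeat_interleave(val_batch_size)   # (12,)

# Row i pairs theta_batch[i % val_batch_size] with validation time
# i // val_batch_size, so each sample is scored at each time exactly once.
assert theta_rep.shape == (val_batch_size * times_batch, dim)
assert times_rep.shape == (val_batch_size * times_batch,)
```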
@@ -489,6 +518,7 @@ def _loss(
masks: Tensor,
proposal: Optional[Any],
calibration_kernel: Callable,
times: Optional[Tensor] = None,
force_first_round_loss: bool = False,
) -> Tensor:
"""Return loss from score estimator. Currently only single-round NPSE
@@ -505,46 +535,11 @@ def _loss(
"""
if self._round == 0 or force_first_round_loss:
# First round loss.
loss = self._neural_net.loss(theta, x)
loss = self._neural_net.loss(theta, x, times)
else:
raise NotImplementedError(
"Multi-round NPSE with arbitrary proposals is not implemented"
)

assert_all_finite(loss, "NPSE loss")
return calibration_kernel(x) * loss

def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
"""Check if training has converged.
Unlike the `._converged` method in base.py, this method does not reset to the
best model. We noticed that this improves performance. Deleting this method
will make C2ST tests fail. This is because the loss is very stochastic, so
resetting might reset to an underfitted model. Ideally, we would write a
custom `._converged()` method which checks whether the loss is still going
down **for all t**.
Args:
epoch: Current epoch.
stop_after_epochs: Number of epochs to wait for improvement on the
validation set before terminating training.
Returns:
Whether training has converged.
"""
converged = False

# No checkpointing, just check if the validation loss has improved.

# (Re)-start the epoch count with the first epoch or any improvement.
if epoch == 0 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
self._epochs_since_last_improvement = 0
else:
self._epochs_since_last_improvement += 1

# If no validation improvement over many epochs, stop training.
if self._epochs_since_last_improvement > stop_after_epochs - 1:
converged = True

return converged

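With the NPSE-specific `_converged` override deleted, training now uses the base trainer's convergence check. The removed code amounted to plain early stopping on the validation loss; a standalone sketch of that logic (the inherited check may differ in details, e.g. by restoring the best model):

```python
# Standalone sketch of the early-stopping logic the deleted override
# implemented; the inherited base-class check may differ in details.
def converged(
    epoch: int, val_loss: float, state: dict, stop_after_epochs: int = 50
) -> bool:
    """True once val_loss has not improved for `stop_after_epochs` epochs."""
    if epoch == 0 or val_loss < state["best_val_loss"]:
        state["best_val_loss"] = val_loss
        state["epochs_since_last_improvement"] = 0
    else:
        state["epochs_since_last_improvement"] += 1
    return state["epochs_since_last_improvement"] > stop_after_epochs - 1


state = {"best_val_loss": float("inf"), "epochs_since_last_improvement": 0}
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.93, 0.91]):
    if converged(epoch, val_loss, state, stop_after_epochs=2):
        break  # stops at epoch 3: no improvement for 2 consecutive epochs
```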