From 5721909f068e88f7ba3e68a1477bc8f35f3a30cc Mon Sep 17 00:00:00 2001 From: juacrumar Date: Fri, 26 Jul 2024 21:40:39 +0200 Subject: [PATCH] fix bug in stopping --- n3fit/src/n3fit/stopping.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/n3fit/src/n3fit/stopping.py b/n3fit/src/n3fit/stopping.py index e2f11a2579..48194cdf1b 100644 --- a/n3fit/src/n3fit/stopping.py +++ b/n3fit/src/n3fit/stopping.py @@ -27,6 +27,7 @@ which will tell `Validation` that no validation set was found and that the training is to be used instead. """ + import logging import numpy as np @@ -345,6 +346,8 @@ def __init__( self._threshold_chi2 = threshold_chi2 self._stopping_degrees = np.zeros(self._n_replicas, dtype=int) self._counts = np.zeros(self._n_replicas, dtype=int) + # Keep track of the replicas that should not be stopped yet + self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool) self._dont_stop = dont_stop self._stop_now = False @@ -451,6 +454,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False): passes &= fitstate.vl_loss < self._best_val_chi2s # And the ones that pass positivity passes &= self._positivity(fitstate) + # Stop replicas that are ok being stopped (because they are finished or otherwise) + passes &= self._dont_stop_me_now self._stopping_degrees += self._counts @@ -470,6 +475,7 @@ def monitor_chi2(self, training_info, epoch, print_stats=False): for i_replica in np.where(stop_replicas)[0]: self._stop_epochs[i_replica] = epoch self._counts[i_replica] = 0 + self._dont_stop_me_now[i_replica] = False # By using the stopping degree we only stop when none of the replicas are improving anymore if min(self._stopping_degrees) > self.stopping_patience: