diff --git a/efg/engine/trainer.py b/efg/engine/trainer.py index d4dc691..66b7cb3 100644 --- a/efg/engine/trainer.py +++ b/efg/engine/trainer.py @@ -219,7 +219,7 @@ def resume_or_load(self, resume=False): all_model_checkpoints = sorted(all_model_checkpoints, key=os.path.getmtime) if len(all_model_checkpoints) > 0: - if self.config.model.weights is not None: + if self.config.model.weights not in ("", None): matched = np.nonzero( np.array([pts.endswith(self.config.model.weights.split("/")[-1]) for pts in all_model_checkpoints]) )[0]