diff --git a/train_scripts/train.py b/train_scripts/train.py
index 72a3ec0..3a56931 100755
--- a/train_scripts/train.py
+++ b/train_scripts/train.py
@@ -947,10 +947,15 @@ def main(cfg: SanaConfig) -> None:
     if rng_state:
         logger.info("resuming randomise")
         torch.set_rng_state(rng_state["torch"])
-        torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
         np.random.set_state(rng_state["numpy"])
         random.setstate(rng_state["python"])
         generator.set_state(rng_state["generator"])  # resume generator status
+        # Restoring the CUDA RNG state is best-effort: on failure, warn and continue.
+        # Catch Exception rather than a bare except so Ctrl-C / SystemExit still propagate.
+        try:
+            torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
+        except Exception as e:
+            logger.warning(f"Failed to resume torch_cuda rng state: {e}")
 
     # Prepare everything
     # There is no specific order to remember, you just need to unpack the