diff --git a/train.py b/train.py index 7ab2043..f32795a 100644 --- a/train.py +++ b/train.py @@ -385,22 +385,6 @@ def train(self): for epoch in range(self.n_epochs): self.epoch = epoch self.train_one_epoch() - # Save after each epoch - print('Epoch completed. Saving..') - state = { - 'net': {key: self.model[key].state_dict() for key in self.model}, - 'optimizer': self.optimizer.state_dict(), - 'scheduler': self.optimizer.scheduler_state_dict(), - 'iters': self.iters, - 'epoch': self.epoch, - } - save_path = os.path.join( - self.log_dir, - f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth' - ) - torch.save(state, save_path) - print(f"Checkpoint saved at {save_path}") - if self.iters >= self.max_steps: break