Remove the recipe state checkpointing *only* on intermediate paths and get resume working w/ StatefulDataLoader
joecummings committed Feb 14, 2025
1 parent 0c817b7 commit 4a3f1b1
Showing 2 changed files with 39 additions and 60 deletions.
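For context on the resume mechanics this commit relies on: the recipe now persists the dataloader's own state and restores it when a run is resumed. Below is a minimal sketch (not code from this commit) of the torchdata StatefulDataLoader save/restore pattern; the dataset and the surrounding loop are illustrative.

# Minimal sketch of the StatefulDataLoader save/resume pattern the recipe
# now depends on (illustrative; not code from this commit).
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(range(100), batch_size=4, shuffle=True)

state = None
for step, batch in enumerate(loader):
    if step == 5:
        # Captured alongside the model/optimizer state in the recipe's
        # checkpoint dict (DATALOADER_STATE_KEY in the diff below).
        state = loader.state_dict()
        break

# On resume, a freshly constructed loader restores that state and continues
# with batch index 6 instead of replaying the epoch from the start.
resumed = StatefulDataLoader(range(100), batch_size=4, shuffle=True)
resumed.load_state_dict(state)
for batch in resumed:
    pass  # picks up where the saved loader left off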
54 changes: 21 additions & 33 deletions recipes/full_finetune_single_device.py
@@ -312,18 +312,12 @@ def setup(self, cfg: DictConfig) -> None:
         self._steps_per_epoch = (
             len(self._dataloader) // self._gradient_accumulation_steps
         )
-        import pdb
-
-        pdb.set_trace()
         if (
             self.max_steps_per_epoch is not None
             and self.max_steps_per_epoch < self._steps_per_epoch
         ):
             self._steps_per_epoch = self.max_steps_per_epoch
 
-        # may have been already loaded if run is resumed
-        self.global_step = self.global_step or self.epochs_run * self._steps_per_epoch
-
         # For now, default to saving at epoch boundaries
         if self.save_every_n_steps is None:
             self.save_every_n_steps = self._steps_per_epoch
@@ -598,31 +592,24 @@ def save_checkpoint(self, *, epoch: int, step: int) -> None:
         Save state dict to file. The recipe save_checkpoint method is responsible for
         correctly creating the checkpoint dict and passing to the checkpointer.
         """
-        ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
-
-        # If training is in-progress, checkpoint the optimizer state as well
-        is_intermediate = step < self._steps_per_epoch * self.total_epochs
-        if is_intermediate:
-            ckpt_dict.update(
-                {
-                    training.SEED_KEY: self.seed,
-                    training.EPOCHS_KEY: epoch,
-                    training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                    CURR_STEP_KEY: step,
-                    training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    DATALOADER_STATE_KEY: self._dataloader.state_dict(),
-                    SAMPLER_STATE_KEY: None,
-                }
-            )
-            if not self._optimizer_in_bwd:
-                ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict()
-            else:
-                ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
+        ckpt_dict = {
+            training.MODEL_KEY: self._model.state_dict(),
+            training.SEED_KEY: self.seed,
+            training.EPOCHS_KEY: epoch,
+            training.TOTAL_EPOCHS_KEY: self.total_epochs,
+            CURR_STEP_KEY: step,
+            training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+            DATALOADER_STATE_KEY: self._dataloader.state_dict(),
+        }
+        if not self._optimizer_in_bwd:
+            ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict()
+        else:
+            ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
 
         self.checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
-            intermediate_checkpoint=is_intermediate,
+            intermediate_checkpoint=True,
             step=step,
         )
 
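The keys written above are exactly what a resume path needs to rebuild the recipe's position in training. A hedged sketch of the corresponding restore step follows; the method name and loading code are assumed for illustration (they are not part of this commit), while the key and attribute names mirror the diff.

# Hypothetical restore path matching the keys saved above (assumed shape,
# not code from this commit).
def _update_recipe_state(self, ckpt_dict: dict) -> None:
    self.seed = ckpt_dict[training.SEED_KEY]
    self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
    self.total_epochs = ckpt_dict[training.TOTAL_EPOCHS_KEY]
    self.global_step = ckpt_dict[CURR_STEP_KEY]
    self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
    # Restoring the dataloader state is what makes mid-epoch resume work.
    self._dataloader.load_state_dict(ckpt_dict[DATALOADER_STATE_KEY])
    if not self._optimizer_in_bwd:
        self._optimizer.load_state_dict(ckpt_dict[training.OPT_KEY])
    else:
        self._optim_ckpt_wrapper.load_state_dict(ckpt_dict[training.OPT_KEY])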
@@ -675,12 +662,14 @@ def train(self) -> None:
             # # in case shuffle is True
             # self._sampler.set_epoch(curr_epoch)
 
+            initial_step = self.global_step % self._steps_per_epoch
             pbar = tqdm(
-                range(self.global_step % self._steps_per_epoch, self._steps_per_epoch),
-                initial=self.global_step % self._steps_per_epoch,
+                total=self._steps_per_epoch,
+                initial=initial_step,
             )
             pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: ?")
             for idx, batch in enumerate(self._dataloader):
+                idx = idx + initial_step
                 if (
                     self.max_steps_per_epoch is not None
                     and (idx // self._gradient_accumulation_steps)
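To make the offset arithmetic above concrete, here is a small worked example with illustrative numbers (100 optimizer steps per epoch, resuming from a checkpoint taken at global step 250).

# Illustrative resume arithmetic (numbers are made up for the example).
steps_per_epoch = 100
global_step = 250                              # restored from CURR_STEP_KEY

initial_step = global_step % steps_per_epoch   # 50 steps into the epoch
assert initial_step == 50

# tqdm starts the progress bar at 50/100, and because the restored
# StatefulDataLoader resumes enumerate() at idx == 0, each batch index is
# shifted back to its true position within the epoch:
idx = 0 + initial_step                         # first resumed batch -> step 50
assert idx == 50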
Expand Down Expand Up @@ -767,9 +756,6 @@ def train(self) -> None:
and self.global_step % self.save_every_n_steps == 0
):
self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
import pdb

pdb.set_trace()

# Reset running stats for the next step
running_loss = 0
@@ -795,7 +781,9 @@ def train(self) -> None:
 
             self.epochs_run += 1
 
-            self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
+            # Save final checkpoint if not already saved during training
+            if self.global_step % self.save_every_n_steps != 0:
+                self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
         self._profiler.stop()
 
     def cleanup(self) -> None:
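The guard above avoids writing the same step twice at an epoch boundary. A small illustration, using the default from the setup() hunk earlier where save_every_n_steps falls back to steps_per_epoch:

# Illustrative check of the end-of-epoch guard (values are made up).
steps_per_epoch = 100
global_step = 100                     # last step of the first epoch

save_every_n_steps = steps_per_epoch  # default: save at epoch boundaries
# The step-based save already fired on step 100, so skip the duplicate:
assert global_step % save_every_n_steps == 0

save_every_n_steps = 30               # mid-epoch cadence
# The last step-based save happened at step 90, so the boundary save runs:
assert global_step % save_every_n_steps != 0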
45 changes: 18 additions & 27 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -837,6 +837,11 @@ def save_checkpoint(
                 f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
                 f"saved to {output_path}"
             )
+            logger.info(
+                "Please note that you have set adapter_only=True, so only adapter weights will be saved."
+                "You need to merge the adapter weights into your base model for further use. "
+                f"See {self.__class__.__name__}.save_checkpoint for more details."
+            )
 
         if self._model_type == ModelType.PHI3_MINI:
             logger.warning(
@@ -918,33 +923,19 @@ def save_checkpoint(
 
         # If the recipe state needs to be output, first remove the model state dict
         # and if it exists, remove the adapter state dict as well
-        if intermediate_checkpoint:
-            _ = state_dict.pop(training.MODEL_KEY, None)
-            _ = state_dict.pop(training.ADAPTER_KEY, None)
-            _ = state_dict.pop(training.ADAPTER_CONFIG, None)
-            output_path = Path.joinpath(
-                self._output_dir, ckpt_save_dirname, "recipe_state.pt"
-            )
-            output_path.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(state_dict, output_path)
-            logger.info(
-                "Recipe checkpoint of size "
-                f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
-                f"saved to {output_path}"
-            )
-        else:
-            logger.info("Saving final checkpoint.")
-            if adapter_only:
-                logger.info(
-                    "Please note that you have set adapter_only=True, so only adapter weights will be saved."
-                    "You need to merge the adapter weights into your base model for further use. "
-                    f"See {self.__class__.__name__}.save_checkpoint for more details."
-                )
-            else:
-                logger.info(
-                    "The full model checkpoint, including all weights and configurations, has been saved successfully. "
-                    "You can now use this checkpoint for further training or inference."
-                )
+        _ = state_dict.pop(training.MODEL_KEY, None)
+        _ = state_dict.pop(training.ADAPTER_KEY, None)
+        _ = state_dict.pop(training.ADAPTER_CONFIG, None)
+        output_path = Path.joinpath(
+            self._output_dir, ckpt_save_dirname, "recipe_state.pt"
+        )
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        torch.save(state_dict, output_path)
+        logger.info(
+            "Recipe checkpoint of size "
+            f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+            f"saved to {output_path}"
+        )
 
         # If specified, prune the checkpoints in the output directory
         if self._keep_last_n_checkpoints is not None:
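Since recipe_state.pt is now written on every checkpoint, a resume path only has to find and load it next to the saved weights. A hedged sketch follows; the helper name and error handling are assumptions, only the path layout comes from the save logic above.

# Hedged sketch of loading the recipe_state.pt written above on resume
# (helper name and fallback behavior are assumptions, not code from this commit).
import os
import torch

def load_recipe_state(output_dir: str, ckpt_save_dirname: str) -> dict:
    recipe_state_path = os.path.join(output_dir, ckpt_save_dirname, "recipe_state.pt")
    if not os.path.exists(recipe_state_path):
        return {}  # nothing to resume from; caller starts a fresh run
    # weights_only=False because the dict holds plain Python objects
    # (step counters, dataloader state), not only tensors.
    return torch.load(recipe_state_path, map_location="cpu", weights_only=False)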
