Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
MikhailKardash committed Oct 23, 2024
1 parent aae47d2 commit 9bee1fa
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions harness/determined/pytorch/deepspeed/_deepspeed_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,9 +1032,20 @@ def _load(self, load_path: pathlib.Path) -> None:
)

save_path = load_path.joinpath("trial_state.pkl")

try:
assert self.state
except AssertionError as e:

if save_path.exists():
with save_path.open("rb") as f:
self._load_state(pickle.load(f))
else:
# Support legacy save states.
wlsq_path = load_path.joinpath("workload_sequencer.pkl")
if wlsq_path.exists():
with wlsq_path.open("rb") as f:
self._load_wlsq_state(pickle.load(f))

def _load_state(self, state: Any) -> None:
# Load our state from the checkpoint if we are continuing training after a pause or restart.
Expand All @@ -1055,6 +1066,26 @@ def _load_state(self, state: Any) -> None:
if self.state.batches_trained == self.val_from_previous_run:
self.state.last_val = self.state.batches_trained

def _load_wlsq_state(self, state: Any) -> None:
if state.get("trial_id") != self.trial_id:
self.state = pytorch._TrialState(trial_id=self.trial_id)
return

self.state = pytorch._TrialState(
trial_id=state.get("trial_id"),
last_ckpt=state.get("last_ckpt"),
last_val=state.get("last_val"),
step_id=state.get("step_id"),
# steps_completed is a legacy field kept to support loading from older checkpoints.
# checkpoints should only persist batches_trained and epochs_trained
batches_trained=state.get("steps_completed"),
epochs_trained=self._get_epoch_idx(state.get("steps_completed")),
)

assert self.state
if self.state.batches_trained == self.val_from_previous_run:
self.state.last_val = self.state.batches_trained

def _save(self, path: pathlib.Path) -> None:
path.mkdir(parents=True, exist_ok=True)

Expand Down

0 comments on commit 9bee1fa

Please sign in to comment.