Skip to content

Commit

Permalink
Merge branch 'master' into loadams/update-a6000-workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jan 24, 2025
2 parents bfd41f8 + 470dd6d commit 5e5b168
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
10 changes: 4 additions & 6 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,10 +799,8 @@ def zero_load_from_fp32_weights(self):
def zero_elastic_checkpoint(self):
    """Return whether elastic checkpointing is enabled in the ZeRO config."""
    zero_cfg = self._config.zero_config
    return zero_cfg.elastic_checkpoint

def zero_has_nvme_offload(self):
    """Return True when the optimizer offloads state to NVMe.

    True when the wrapped optimizer either swaps optimizer state or keeps
    parameters split between NVMe and CPU; False when the optimizer has no
    swap support at all (no ``swap_optimizer`` attribute).
    """
    opt = self.optimizer
    if hasattr(opt, "swap_optimizer"):
        return opt.swap_optimizer or opt.params_in_nvme_and_cpu
    return False
def zero_nvme_offload_optimizer(self):
    """Return the optimizer's ``swap_optimizer`` flag, or False if absent."""
    try:
        return self.optimizer.swap_optimizer
    except AttributeError:
        # Optimizer implementations without NVMe swap support lack the
        # attribute entirely; treat that as "no NVMe offload".
        return False

def zero_max_live_parameters(self):
    """Return the configured ZeRO ``max_live_parameters`` limit."""
    zero_cfg = self._config.zero_config
    return zero_cfg.max_live_parameters
Expand Down Expand Up @@ -2865,7 +2863,7 @@ def load_checkpoint(self,
if not success:
self.optimizer._restore_from_bit16_weights()

if self.zero_has_nvme_offload():
if self.zero_nvme_offload_optimizer():
from shutil import copytree, disk_usage
offload_dir = self.optimizer.optimizer_swapper.swap_folder
offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
Expand Down Expand Up @@ -3205,7 +3203,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
self._create_zero_checkpoint_files(save_dir, tag)
self._save_zero_checkpoint(save_dir, tag)

if self.zero_has_nvme_offload():
if self.zero_nvme_offload_optimizer():
from shutil import copytree, disk_usage
offload_dir = self.optimizer.optimizer_swapper.swap_folder
offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
Expand Down
5 changes: 5 additions & 0 deletions deepspeed/runtime/swap_tensor/optimizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume
'timer_names',
]

def purge_state(self):
    """Discard swapped-out optimizer state for every tracked parameter.

    Each swap-info entry keeps only its first tensor (the parameter
    itself) and is marked as holding no optimizer-state tensors.
    Used when loaded checkpoint state supersedes the freshly
    initialized swapped state.
    """
    for info in self.swap_params_info.values():
        param_tensor = info.tensors[0]
        info.tensors = [param_tensor]
        info.has_state_tensors = False

def swappable_tensor(self, param=None, numel=None):
assert param is not None or numel is not None, "Either param or numel must be provided"
if param is not None:
Expand Down
12 changes: 4 additions & 8 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
self._clear_fp32_optimizer_param_groups()

if self.swap_optimizer or self.params_in_nvme_and_cpu:
if self.swap_optimizer:
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
for swap_info in self.optimizer_swapper.swap_params_info.values():
swap_info.tensors = [swap_info.tensors[0]]
swap_info.has_state_tensors = False
self.optimizer_swapper.purge_state()

if self.swap_optimizer:
# Touch all parameters to synchronize all buffers
Expand Down Expand Up @@ -2773,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
else:
optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor

if self.swap_optimizer or self.params_in_nvme_and_cpu:
if self.swap_optimizer:
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
for swap_info in self.optimizer_swapper.swap_params_info.values():
swap_info.tensors = [swap_info.tensors[0]]
swap_info.has_state_tensors = False
self.optimizer_swapper.purge_state()

if self.swap_optimizer:
# Touch all parameters to synchronize all buffers
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/runtime/zero/test_nvme_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
world_size = 1

@pytest.mark.parametrize('param_offload_device, optim_offload_device',
[(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
[(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
(OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
Expand Down

0 comments on commit 5e5b168

Please sign in to comment.