From 470dd6dcebf712d620b95c9112fafb19dc94ba44 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 23 Jan 2025 11:42:06 -0500 Subject: [PATCH] Precisely track nvme optimizer offload (#6963) Fix #4998 --- deepspeed/runtime/engine.py | 10 ++++------ deepspeed/runtime/swap_tensor/optimizer_utils.py | 5 +++++ deepspeed/runtime/zero/stage3.py | 12 ++++-------- tests/unit/runtime/zero/test_nvme_checkpointing.py | 4 +++- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9b9a2e509d61..97d2afb8b723 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -799,10 +799,8 @@ def zero_load_from_fp32_weights(self): def zero_elastic_checkpoint(self): return self._config.zero_config.elastic_checkpoint - def zero_has_nvme_offload(self): - if not hasattr(self.optimizer, "swap_optimizer"): - return False - return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu + def zero_nvme_offload_optimizer(self): + return getattr(self.optimizer, "swap_optimizer", False) def zero_max_live_parameters(self): return self._config.zero_config.max_live_parameters @@ -2865,7 +2863,7 @@ def load_checkpoint(self, if not success: self.optimizer._restore_from_bit16_weights() - if self.zero_has_nvme_offload(): + if self.zero_nvme_offload_optimizer(): from shutil import copytree, disk_usage offload_dir = self.optimizer.optimizer_swapper.swap_folder offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors") @@ -3205,7 +3203,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, self._create_zero_checkpoint_files(save_dir, tag) self._save_zero_checkpoint(save_dir, tag) - if self.zero_has_nvme_offload(): + if self.zero_nvme_offload_optimizer(): from shutil import copytree, disk_usage offload_dir = self.optimizer.optimizer_swapper.swap_folder offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors") diff --git a/deepspeed/runtime/swap_tensor/optimizer_utils.py b/deepspeed/runtime/swap_tensor/optimizer_utils.py index 389ad6ae1076..5d837e386a95 100644 --- a/deepspeed/runtime/swap_tensor/optimizer_utils.py +++ b/deepspeed/runtime/swap_tensor/optimizer_utils.py @@ -153,6 +153,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume 'timer_names', ] + def purge_state(self): + for swap_info in self.swap_params_info.values(): + swap_info.tensors = [swap_info.tensors[0]] + swap_info.has_state_tensors = False + def swappable_tensor(self, param=None, numel=None): assert param is not None or numel is not None, "Either param or numel must be provided" if param is not None: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a5c0c3340019..3195c973a179 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2652,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) self._clear_fp32_optimizer_param_groups() - if self.swap_optimizer or self.params_in_nvme_and_cpu: + if self.swap_optimizer: # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint - for swap_info in self.optimizer_swapper.swap_params_info.values(): - swap_info.tensors = [swap_info.tensors[0]] - swap_info.has_state_tensors = False + self.optimizer_swapper.purge_state() if self.swap_optimizer: # Touch all parameters to synchronize all buffers @@ -2773,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa else: optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor - if self.swap_optimizer or self.params_in_nvme_and_cpu: + if self.swap_optimizer: # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint - for swap_info in self.optimizer_swapper.swap_params_info.values(): - swap_info.tensors = [swap_info.tensors[0]] - swap_info.has_state_tensors = False + self.optimizer_swapper.purge_state() if self.swap_optimizer: # Touch all parameters to synchronize all buffers diff --git a/tests/unit/runtime/zero/test_nvme_checkpointing.py b/tests/unit/runtime/zero/test_nvme_checkpointing.py index 850c8eb3e349..01a75aa64b4e 100644 --- a/tests/unit/runtime/zero/test_nvme_checkpointing.py +++ b/tests/unit/runtime/zero/test_nvme_checkpointing.py @@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest): world_size = 1 @pytest.mark.parametrize('param_offload_device, optim_offload_device', - [(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu), + [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme), (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme), + (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none), + (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu), (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)]) def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device): zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")