diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 61e8ec64c..bfbc916c8 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -92,6 +92,7 @@
         make_sharded_optimizer_tensor,
         optim_state_to_sharding_state,
     )
+    from megatron.core.dist_checkpointing.strategies import tensorstore
    from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
    from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
    from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
@@ -254,7 +255,7 @@ def configure_ddp(self):
        else:
            super().configure_ddp()

-    def optimizer_sharded_state_dict(self):
+    def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
        """
        Sharded state dictionary for an MainParamsOptimizerWrapper.
        Used to save and load the optimizer state when training with distributed_checkpoint.
@@ -274,7 +275,7 @@ def optimizer_sharded_state_dict(self):
        }

        if isinstance(optimizer, MegatronDistributedFusedAdam):
-            return optimizer.sharded_state_dict(model_sharded_state_dict)
+            return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
            # Regular optimizer, e.g. Adam or FusedAdam
            init_optimizer_states(optimizer)
@@ -337,9 +338,14 @@ def save_checkpoint(
            hasattr(self.lightning_module, 'sharded_state_dict')
            and self.lightning_module.sharded_state_dict() is not None
        ):
+            assert (
+                len(checkpoint['optimizer_states']) == 1
+            ), "Currently only support checkpointing 1 distributed optimizer per time!"
            # converts the optimizer states to their sharded equivalents
-            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-
+            sharded_optim_state = self.optimizer_sharded_state_dict(
+                unsharded_optim_state=checkpoint['optimizer_states'][0]
+            )
+            checkpoint['optimizer_states'] = [sharded_optim_state]
            # dist_checkpointing expects a directory so we will name the directory
            # using the path with the file extension removed
            checkpoint_dir = ckpt_to_dir(filepath)
@@ -437,9 +443,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            checkpoint['state_dict'] = sharded_state_dict
            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]

-            checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path)
-
-            checkpoint = self._fix_tensors_device(checkpoint)
+            if self.torch_dist_ckpt:
+                sharded_strategy = ('torch_dist', 1)
+            else:
+                sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
+            checkpoint = dist_checkpointing.load(
+                sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=sharded_strategy
+            )

            return checkpoint

diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index a2316dabb..a85747c9f 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -549,8 +549,9 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
        # Handle any remaining dtype conversions
        super()._check_params_shard_dtypes(params_buckets)

-    def sharded_state_dict(self, model_sharded_state_dict):
-        optimizer_state_dict = self.state_dict()
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        if optimizer_state_dict is None:
+            optimizer_state_dict = self.state_dict()
        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict=model_sharded_state_dict,
            optim_params_iter=self.parameters(),
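
For reference, a minimal standalone sketch (not part of the patch) of the calling pattern these hunks introduce: save_checkpoint passes the already-gathered optimizer state (checkpoint['optimizer_states'][0]) down through optimizer_sharded_state_dict as unsharded_optim_state, and the optimizer's sharded_state_dict only falls back to self.state_dict() when nothing was supplied. ToyOptimizer below is a hypothetical stand-in for MegatronDistributedFusedAdam; the real signatures, sharding map, and collectives are not reproduced here.

    class ToyOptimizer:
        """Hypothetical stand-in for MegatronDistributedFusedAdam (sketch only)."""

        def __init__(self):
            self.gather_calls = 0

        def state_dict(self):
            # In the real optimizer this triggers an expensive state gather.
            self.gather_calls += 1
            return {'state': {}, 'param_groups': []}

        def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
            # Mirrors the distributed_adam.py change: only gather when the caller
            # did not already provide the unsharded optimizer state.
            if optimizer_state_dict is None:
                optimizer_state_dict = self.state_dict()
            # ...mapping onto model_sharded_state_dict omitted in this sketch...
            return optimizer_state_dict

    optimizer = ToyOptimizer()
    checkpoint = {'optimizer_states': [optimizer.state_dict()]}  # gathered once by the trainer
    sharded = optimizer.sharded_state_dict({}, optimizer_state_dict=checkpoint['optimizer_states'][0])
    checkpoint['optimizer_states'] = [sharded]
    assert optimizer.gather_calls == 1  # no second gather on the save path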