Skip to content

Commit

Permalink
Fix checkpoints.py by checking and allowing single-leaf targets.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 462786256
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 23, 2022
1 parent e4ef826 commit 6e71622
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
24 changes: 21 additions & 3 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def _checkpoint_path_step(path: str) -> Optional[float]:

def _split_gdas(
target: Dict[str, Any]) -> Tuple[Dict[str, Any], List[GlobalDeviceArray]]:
# When target is a single leaf instead of a pytree dict.
if not isinstance(target, (core.FrozenDict, dict)):
if isinstance(target, GlobalDeviceArray):
return GDA_PH, [target]
return target, []
# Traverse the target and handle GlobalDeviceArrays.
flattened = traverse_util.flatten_dict(target, keep_empty_nodes=True)
gda_targets = []
Expand Down Expand Up @@ -99,14 +104,27 @@ def _save_gdas(gda_manager: GlobalAsyncCheckpointManager,


def _restore_gdas(state_dict,
target: Optional[PyTree],
target: Optional[Any],
ckpt_path: str,
step: Optional[int] = None,
gda_manager: Optional[GlobalAsyncCheckpointManager] = None):

# When target is a single leaf instead of a pytree dict.
if not isinstance(state_dict, (core.FrozenDict, dict)):
if isinstance(target, GlobalDeviceArray) and isinstance(
state_dict, GlobalDeviceArray):
if not gda_manager:
raise errors.GDACheckpointingRequiredError(ckpt_path, step)
if not target:
raise errors.GDARestoreTargetRequiredError(ckpt_path, step)
gda_list = gda_manager.deserialize(
[target.mesh], [target.mesh_axes],
[get_tensorstore_spec(ckpt_path + '_gda')])
return gda_list[0]
return state_dict

# Check if a GDA is present in the restored pytree
flattened = traverse_util.flatten_dict(state_dict, keep_empty_nodes=True)

gda_paths = []
for key, value in flattened.items():
if isinstance(value, str) and value.startswith(GDA_PH):
Expand Down Expand Up @@ -355,7 +373,7 @@ def latest_checkpoint(ckpt_dir: Union[str, os.PathLike],

def restore_checkpoint(
ckpt_dir: Union[str, os.PathLike],
target: Optional[PyTree],
target: Optional[Any],
step: Optional[int] = None,
prefix: str = 'checkpoint_',
parallel: bool = True,
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ filterwarnings =
ignore:jax.tree_util.tree_multimap\(\) is deprecated.*:FutureWarning
# traverse_util.Traversal will be removed soon.
ignore:`flax.traverse_util.Traversal` will be deprecated.*:DeprecationWarning
# TODO: Will revisit all the deprecation warnings next week.
ignore:jax.tree_.*:FutureWarning

0 comments on commit 6e71622

Please sign in to comment.