
Commit d32d207
merge with main and update
bernardhan33 committed Oct 15, 2024
2 parents 254e2e5 + 5d1e7a3 commit d32d207
Showing 162 changed files with 7,908 additions and 2,461 deletions.
4 changes: 2 additions & 2 deletions .github/CODEOWNERS
@@ -1,2 +1,2 @@
-# Changes in this file should match with requiredReviewers in .github/workflows/AddLabel.yml
-* @gobbleturk
+# Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
+* @gobbleturk @jonb377 @khatwanimohit @bvandermoon @vipannalla
8 changes: 7 additions & 1 deletion .github/workflows/AddLabel.yml
@@ -52,7 +52,13 @@ jobs:
}
// This list should match with CODEOWNERS
-let requiredReviewers = { gobbleturk: "" }
+let requiredReviewers = {
+  gobbleturk: "",
+  jonb377: "",
+  khatwanimohit: "",
+  bvandermoon: "",
+  vipannalla: "",
+}
const reviews = await github.rest.pulls.listReviews({
owner,
repo,
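
Both files warn that the reviewer lists must match, but nothing enforces it. As a sketch only, a small Python check along these lines could run in CI; the parsing below is an assumption for illustration and is not part of this commit:

"""Sketch: fail CI if CODEOWNERS and AddLabel.yml reviewer lists drift apart."""
import re

def codeowners_reviewers(path=".github/CODEOWNERS"):
  # The '*' rule lists owners as space-separated @handles.
  with open(path, encoding="utf-8") as f:
    for line in f:
      if line.startswith("*"):
        return {handle.lstrip("@") for handle in line.split()[1:]}
  return set()

def addlabel_reviewers(path=".github/workflows/AddLabel.yml"):
  # Naive scan for the keys of the inline requiredReviewers object.
  with open(path, encoding="utf-8") as f:
    text = f.read()
  block = re.search(r"requiredReviewers = \{(.*?)\}", text, re.DOTALL)
  return set(re.findall(r"(\w+):", block.group(1))) if block else set()

if __name__ == "__main__":
  assert codeowners_reviewers() == addlabel_reviewers(), "reviewer lists drifted"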
9 changes: 9 additions & 0 deletions .github/workflows/CPUTests.yml
@@ -36,3 +36,12 @@ jobs:
pylint pedagogical_examples/ && \
echo 'PyLint check on pedagogical_examples/ is successful' || { echo \
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
+- name: Analysing the code with pyink in Maxtext/
+  run: |
+    pyink MaxText --check --diff --color --pyink-indentation=2 --line-length=125
+- name: Analysing the code with pyink in pedagogical_examples/
+  run: |
+    pyink pedagogical_examples --check --diff --color --pyink-indentation=2 --line-length=125
8 changes: 8 additions & 0 deletions .github/workflows/UploadDockerImages.yml
@@ -31,6 +31,8 @@ jobs:
runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"]
steps:
- uses: actions/checkout@v3
+- name: Cleanup old docker images
+  run: docker system prune --all --force
- name: build jax stable image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable MODE=stable DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable
@@ -39,7 +41,11 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly
- name: build jax stable stack image
run : |
-bash docker_maxtext_jax_stable_stack_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxtext-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true
+bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
gpu:
strategy:
fail-fast: false
@@ -48,6 +54,8 @@ jobs:
runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"]
steps:
- uses: actions/checkout@v3
+- name: Cleanup old docker images
+  run: docker system prune --all --force
- name: build jax stable image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable MODE=stable DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_stable
2 changes: 1 addition & 1 deletion .github/workflows/build_and_upload_images.sh
@@ -39,8 +39,8 @@ if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] || [[ !
exit 1
fi

+gcloud auth configure-docker us-docker.pkg.dev --quiet
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
-gcloud auth configure-docker --quiet
image_date=$(date +%Y-%m-%d)

# Upload only dependencies image
2 changes: 1 addition & 1 deletion MaxText/accelerator_to_spec_map.py
@@ -155,7 +155,7 @@ class SystemCharacteristics:
# across hosts will occur over DCN. This makes the "slice" topology of A3 fixed to a single host.
# To use AoT compilation with multihost, the `compile_topology_num_slices` flag should be
# specified to the number of hosts.
"a3": SystemCharacteristics("gpu", None, None, None, 8, None)
"a3": SystemCharacteristics("gpu", None, None, None, 8, None),
}
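
The comment above doubles as a how-to: since "a3" pins the slice topology to a single host, a multihost A3 job compiles ahead of time with one slice per host. A minimal sketch of such an invocation, where the train_compile.py entry point and the exact flag spellings are assumptions based on the comment:

# Sketch: AoT compilation for a 4-host A3 cluster. Each host is one "slice"
# (fast interconnect inside the host, DCN across hosts), so the slice count
# equals the host count. Entry point and flag names are assumptions.
import subprocess

num_hosts = 4  # illustrative
subprocess.run(
  [
    "python", "MaxText/train_compile.py", "MaxText/configs/base.yml",
    "compile_topology=a3",
    f"compile_topology_num_slices={num_hosts}",
  ],
  check=True,
)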


137 changes: 66 additions & 71 deletions MaxText/checkpointing.py
@@ -16,27 +16,30 @@

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Optional, Union
from typing import Any, Optional, Union
from absl import flags
from etils import epath
import orbax.checkpoint
from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger
from orbax.checkpoint import pytree_checkpoint_handler
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, PyTree
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import jax
import numpy as np
from flax.training import orbax_utils, train_state
import grain.python as grain

import jax
import max_logging
from multihost_dataloading import MultiHostDataLoadIterator
from flax.training import orbax_utils, train_state
import numpy as np
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager

PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler
# pylint: disable=too-many-positional-arguments

CheckpointManager = ocp.CheckpointManager
CheckpointManagerOptions = ocp.CheckpointManagerOptions
PyTreeCheckpointHandler = ocp.PyTreeCheckpointHandler
LocalCheckpointOptions = emergency_checkpoint_manager.LocalCheckpointOptions
PersistentCheckpointOptions = (
emergency_checkpoint_manager.PersistentCheckpointOptions
)
PersistentCheckpointOptions = emergency_checkpoint_manager.PersistentCheckpointOptions

abstract_logger = ocp.logging.abstract_logger
cloud_logger = ocp.logging.cloud_logger
composite_logger = ocp.logging.composite_logger
standard_logger = ocp.logging.standard_logger


def create_orbax_checkpoint_manager(
@@ -46,6 +49,8 @@ def create_orbax_checkpoint_manager(
save_interval_steps: int,
dataset_type: Optional[str] = "tfds",
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
use_ocdbt: bool = True,
use_zarr3: bool = True,
max_to_keep: int = None,
enable_background_delete: bool = False,
):
@@ -62,17 +67,24 @@
item_names = ("items",)
if max_to_keep < 0:
max_to_keep = None

# local storage checkpoint needs parent directory created
p.mkdir(exist_ok=True, parents=True)
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
# omitting `iter` uses default handler for `iter`
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
mngr = CheckpointManager(
p,
item_names=item_names,
item_handlers=item_handlers,
options=CheckpointManagerOptions(
create=True,
save_interval_steps=save_interval_steps,
enable_async_checkpointing=use_async,
max_to_keep=max_to_keep,
enable_background_delete=enable_background_delete,
),
logger=orbax_logger
logger=orbax_logger,
)
max_logging.log("Checkpoint manager created!")
return mngr
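
A usage sketch for the manager configured above. The leading positional parameters are not visible in this hunk, so their meaning here is an assumption, and all values are illustrative:

# Sketch: a manager that writes OCDBT/Zarr3 checkpoints (bounding checkpoint
# file sizes, per the comment above) every 1000 steps, keeping the newest 5.
mngr = create_orbax_checkpoint_manager(
  "gs://my-bucket/my-run/checkpoints",  # assumed: checkpoint directory
  True,  # assumed: enable checkpointing
  True,  # assumed: save asynchronously
  save_interval_steps=1000,
  dataset_type="grain",  # "iter" falls back to the default handler, as noted above
  use_ocdbt=True,
  use_zarr3=True,
  max_to_keep=5,
)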
@@ -82,7 +94,7 @@ def create_orbax_emergency_checkpoint_manager(
local_checkpoint_dir: str,
persistent_checkpoint_dir: str,
global_mesh: jax.sharding.Mesh,
abstract_state: PyTree,
abstract_state: Any,
local_save_interval_steps: int,
persistent_save_interval_steps: int,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
@@ -92,12 +104,8 @@
max_logging.log("Creating emergency checkpoint manager...")

options = emergency_checkpoint_manager.CheckpointManagerOptions(
local=LocalCheckpointOptions(
save_interval_steps=local_save_interval_steps
),
persistent=PersistentCheckpointOptions(
save_interval_steps=persistent_save_interval_steps
),
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
)
emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
@@ -183,53 +191,49 @@ def map_to_pspec(data):
pspec = data.sharding.spec
mesh = data.sharding.mesh
if not enable_single_replica_ckpt_restoring:
return orbax.checkpoint.type_handlers.ArrayRestoreArgs(
mesh=mesh, mesh_axes=pspec)
return ocp.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(
replica_mesh, pspec)

array_handler = (
orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000 # 1000 MB limit
)
)
orbax.checkpoint.type_handlers.register_type_handler(
jax.Array,
array_handler,
override=True
)

return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs(
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
global_shape=data.shape,
dtype=data.dtype,
)

if enable_single_replica_ckpt_restoring:
array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(
map_to_pspec,
abstract_unboxed_pre_state,
)

if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
return (
checkpoint_manager.restore(
latest_step,
args=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args),
),
None,
checkpoint_manager.restore(
latest_step,
args=ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args),
),
None,
)

if dataset_type == "grain" and data_iterator is not None:
return (
checkpoint_manager.restore(
latest_step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args),
args=ocp.args.Composite(
items=ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
restore_args=restore_args,
),
iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator),
),
),
@@ -239,8 +243,11 @@ def map_to_pspec(data):
return (
checkpoint_manager.restore(
latest_step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
args=ocp.args.Composite(
items=ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
restore_args=restore_args,
)
),
),
None,
@@ -252,7 +259,7 @@ def map_to_pspec(data):
elif load_full_state_from_path != "":
max_logging.log(f"restoring full state from {load_full_state_from_path=}")
p = epath.Path(load_full_state_from_path)
ckptr = orbax.checkpoint.StandardCheckpointer()
ckptr = ocp.StandardCheckpointer()
restored = ckptr.restore(p, abstract_unboxed_pre_state)
return {"items": restored}, None

@@ -273,9 +280,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None:
max_logging.log("Setting up checkpoint logger...")
if config.enable_checkpoint_cloud_logger:
logger_name = f"checkpoint_{config.run_name}"
options = cloud_logger.CloudLoggerOptions(
job_name=config.run_name, logger_name=logger_name
)
options = cloud_logger.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name)
orbax_cloud_logger = cloud_logger.CloudLogger(options=options)
max_logging.log("Successfully set up checkpoint cloud logger.")

@@ -285,9 +290,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None:

orbax_logger = None
if orbax_cloud_logger is not None and orbax_standard_logger is not None:
orbax_logger = composite_logger.CompositeLogger(
orbax_cloud_logger, orbax_standard_logger
)
orbax_logger = composite_logger.CompositeLogger(orbax_cloud_logger, orbax_standard_logger)
max_logging.log("Successfully set up checkpoint composite logger.")

return orbax_logger
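
A toy sketch of the wiring above: with cloud logging enabled, checkpoint events fan out to both sinks through a CompositeLogger. The stub config carries only the fields visible in this hunk; the real config object has more:

# Sketch: stand-in config exercising setup_checkpoint_logger.
class _StubConfig:
  run_name = "my-run"
  enable_checkpoint_cloud_logger = True

orbax_logger = setup_checkpoint_logger(_StubConfig())  # CompositeLogger or None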
@@ -298,30 +301,22 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")
ckpt = epath.Path(load_parameters_from_path)
ckptr = orbax.checkpoint.PyTreeCheckpointer()
ckptr = ocp.PyTreeCheckpointer()
# This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
# Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
# memory, we instead specify here that we are just restoring the params field of the checkpoint
# (which itself may be a dictionary containing a key named 'params').
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
restored = ckptr.restore(
ckpt,
item={"params": abstract_unboxed_params},
transforms={},
restore_args={"params": restore_args}
)
ckpt, item={"params": abstract_unboxed_params}, transforms={}, restore_args={"params": restore_args}
)
return restored["params"]
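
The comment above is the payoff: only the params subtree is restored, so optimizer state is never materialized. A usage sketch, with an assumed checkpoint layout and an assumed way of building the abstract params:

# Sketch: restore decode-time params only. abstract_unboxed_params stands for
# an abstract pytree of the model parameters (e.g. built with jax.eval_shape).
params = load_params_from_path(
  "gs://my-bucket/my-run/checkpoints/1000/items",  # assumed path layout
  abstract_unboxed_params,
)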


def save_params_to_path(checkpoint_dir, params):
"""Save decode params in checkpoint at specified path."""
assert checkpoint_dir, "checkpoint_dir is not defined."
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target({"params":params})
orbax_checkpointer.save(
checkpoint_dir,
{"params":params},
save_args=save_args,
force=True
)
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target({"params": params})
orbax_checkpointer.save(checkpoint_dir, {"params": params}, save_args=save_args, force=True)
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")
5 changes: 5 additions & 0 deletions MaxText/common_types.py
@@ -19,6 +19,7 @@
from flax.linen import partitioning
import jax
import jax.numpy as jnp
+import numpy as np

Config = Any

@@ -55,3 +56,7 @@
MODEL_MODE_TRAIN = "train"

DECODING_ACTIVE_SEQUENCE_INDICATOR = 1

+# A large negative mask value is used for masking to ensure that the
+# softmax function assigns an extremely low probability to the masked positions.
+DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
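
A toy illustration (not from this commit) of how such a mask value is applied ahead of a softmax:

# Sketch: masked positions receive DEFAULT_MASK_VALUE before the softmax, so
# they end up with ~0 probability. Using -0.7 * float32-max instead of -inf
# keeps the arithmetic finite and avoids inf - inf = nan in the softmax.
import jax
import jax.numpy as jnp
import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)

scores = jnp.array([[2.0, 1.0, 0.5]])
allowed = jnp.array([[True, True, False]])  # last position is masked out
probs = jax.nn.softmax(jnp.where(allowed, scores, DEFAULT_MASK_VALUE), axis=-1)
# probs[0, 2] is ~0; probs[0, :2] match a softmax over just the unmasked scores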
