diff --git a/axlearn/common/factorized_rms_test.py b/axlearn/common/factorized_rms_test.py
index 931c16c6..94d47168 100644
--- a/axlearn/common/factorized_rms_test.py
+++ b/axlearn/common/factorized_rms_test.py
@@ -12,11 +12,12 @@
 from axlearn.common import factorized_rms
 from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
 from axlearn.common.optimizer_base import (
-    NestedOptStateSpec,
+    Nested,
     OptParam,
+    OptStateSpec,
     PartitionedGradientTransformation,
 )
-from axlearn.common.optimizers import OptStateSpec, with_partition_fn
+from axlearn.common.optimizers import with_partition_fn
 from axlearn.common.test_utils import TestCase
 from axlearn.common.utils import PartitionSpec, flatten_items
@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):
         # The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
         # factorization spec.
-        exp_partition: NestedOptStateSpec = exp.partition(param_specs)
+        exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
         # Used for `count`.
         count_spec = OptStateSpec(
             dtype=jnp.int32,
diff --git a/axlearn/common/optimizer_base.py b/axlearn/common/optimizer_base.py
index 8c156bb7..25c1163d 100644
--- a/axlearn/common/optimizer_base.py
+++ b/axlearn/common/optimizer_base.py
@@ -16,14 +16,13 @@
 - weight_decay_scale: control the weight decay rate.
 """
 import dataclasses
-from collections.abc import Sequence
 from typing import Any, Callable, NamedTuple, Optional, Union

 import optax
 import typing_extensions

-from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
-from axlearn.common.utils import Tensor, TensorSpec
+from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
+from axlearn.common.utils import Nested, Tensor, TensorSpec


 @dataclasses.dataclass
@@ -66,8 +65,7 @@ def __call__(

 # Specification of an optimizer state array.
 OptStateSpec = TensorSpec
-NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
-TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
+TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


 class PartitionedGradientTransformation(NamedTuple):
diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index 70517ad5..8c2d7be4 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -36,10 +36,11 @@
 import typing_extensions
 from absl import logging
 from jax import numpy as jnp
+from jax._src.sharding_impls import TransferToMemoryKind
 from optax._src import numerics

 from axlearn.common import schedule, struct
-from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
+from axlearn.common.base_layer import ParameterSpec, PartitionSpec
 from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.factorized_rms import scale_by_factored_rms
 from axlearn.common.module import current_context
@@ -51,8 +52,8 @@
     TransformPartitionSpecFn,
 )
 from axlearn.common.utils import (
+    MemoryKind,
     Nested,
-    NestedPartitionSpec,
     NestedTensor,
     NestedTree,
     Tensor,
@@ -139,19 +140,41 @@ def update_fn(
     return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


-def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+def copy_partition(
+    specs: Nested[OptStateSpec],
+    *,
+    pattern: Union[None, str, re.Pattern] = None,
+    memory_kind: Optional[MemoryKind] = None,
+) -> Nested[OptStateSpec]:
+    """Copies OptStateSpec and optionally assigns a different memory kind.
+
+    Args:
+        specs: Nested[OptStateSpec] to copy from.
+        pattern: Regex to match the full path of each spec. Matched specs will have their memory
+            kind replaced with `memory_kind`.
+        memory_kind: New memory kind. Defaults to None.
+
+    Returns:
+        A Nested[OptStateSpec] with possibly a different memory kind.
+    """
     return jax.tree.map(
-        lambda param_spec: OptStateSpec(
-            dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
+        lambda path, spec: OptStateSpec(
+            dtype=spec.dtype,
+            shape=spec.shape,
+            mesh_axes=spec.mesh_axes,
+            memory_kind=memory_kind
+            if pattern and re.fullmatch(pattern, path)
+            else spec.memory_kind,
         ),
-        param_specs,
+        tree_paths(specs),
+        specs,
     )


 def trace_partition(
     base: optax.GradientTransformation,
 ) -> PartitionedGradientTransformation:
-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
         return optax.TraceState(trace=copy_partition(param_specs))

     return with_partition_fn(base, partition_fn)
@@ -160,7 +183,9 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
 def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
     state: optax.ScaleByAdamState = base.init({})

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, optax.ScaleByAdamState]]:
         return optax.ScaleByAdamState(
             count=OptStateSpec(
                 dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +975,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
         )
         return updates, new_state

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[Union[OptStateSpec, EmaState]]:
         def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
             # Store momentum in accumulator_dtype if it is set and p is not scalar.
             if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1437,9 @@ def _is_valid_step(
             drop_stats=new_drop_stats,
         )

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, SkipClipState]]:
         if use_adaptive_drop_norm:
             one = jnp.ones([], jnp.float32)
             dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1598,9 @@ def update_fn(updates, state, params):
         )
         return updates, ParamEmaState(count=count_inc, ema=new_ema)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ParamEmaState]]:
         return ParamEmaState(
             count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
             ema=copy_partition(param_specs),
@@ -1617,7 +1646,9 @@ def update_fn(updates, state, params=None):
         updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
         return updates, ScaleByLionState(count=count_inc, mu=mu)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ScaleByLionState]]:
         mu_specs = param_specs
         if mu_dtype is not None:
             mu_specs = jax.tree.map(
@@ -1993,3 +2024,100 @@ def _update2(u: Tensor, param: OptParam):
         partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
     )
     return named_chain(**tx)
+
+
+def offload_optimizer(
+    optimizer: ConfigOr[PartitionedGradientTransformation],
+    *,
+    pattern: Union[str, re.Pattern] = ".*",
+    offload_src: MemoryKind = "device",
+    offload_dst: MemoryKind = "pinned_host",
+) -> PartitionedGradientTransformation:
+    """Offloads the states of the wrapped optimizer that match `pattern` to `offload_dst`.
+
+    Args:
+        optimizer: The optimizer to offload.
+        pattern: Regex pattern used to match the path of optimizer states. Fully matched states
+            will be offloaded. Defaults to a regex that matches all states.
+        offload_src: Offload-from memory kind. Defaults to "device".
+        offload_dst: Offload-to memory kind. Defaults to "pinned_host".
+
+    Returns:
+        An optimizer whose state is on `offload_dst` and does the same computation as `optimizer`.
+
+    Raises:
+        ValueError: When the `update` function of the returned optimizer is called outside of a
+            jit context.
+
+    This function returns a new `PartitionedGradientTransformation` that:
+    1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
+       during state initialization in the trainer.
+    2. Copies the matched states to `offload_src` before `optimizer.update` is called.
+    3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.
+
+    The regex pattern is matched against the full path of each optimizer state. An example full
+    path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
+    pattern should not depend on model structure, you can use ".*/mu/.*" to offload all `mu`.
+
+    The `update` function of the returned `PartitionedGradientTransformation` must be called
+    within a jit function.
+
+    Example usage:
+    ```python
+    your_opt = adamw_optimizer(...)
+    offloaded_opt = offload_optimizer(your_opt)
+    ```
+
+    When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the
+    entire `skip_and_clip_by_global_norm` inside `offload_optimizer`. Do not wrap only its inner
+    optimizer, or you will get errors. Correct example:
+    ```
+    offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
+    ```
+    The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
+    optimizer state and the newly updated optimizer state using `jnp.where`, which doesn't support
+    tensors on the `pinned_host` memory space.
+    """
+    optimizer = maybe_instantiate(optimizer)
+    if offload_src is None or offload_dst is None:
+        raise ValueError(
+            "offload_src and offload_dst cannot be None when using optimizer offloading."
+        )
+
+    logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)
+
+    def init_fn(params: NestedOptParam):
+        return optimizer.init(params)
+
+    def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
+        # TransferToMemoryKind lets us change the memory kind of tensors without specifying the
+        # full sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about
+        # it, it's specified in the API signature. Reference:
+        # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
+        # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
+        # device_put, each array in the pytree is moved independently of the others. The exact
+        # order is decided by the latency hiding scheduler. The scheduler will try to overlap the
+        # transfers of each state with the state update on TPU whenever possible. There is some
+        # memory spike due to the temporary states in HBM, but the spike is much smaller than the
+        # full memory usage of all states. Moreover, when the optimizer is run, all activations
+        # are released, so we have less memory pressure at that point in time.
+        return jax.tree.map(
+            lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
+            if re.fullmatch(pattern, path)
+            else tensor,
+            tree_paths(state),
+            state,
+        )
+
+    def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
+        state = _move_fn(state, offload_src)
+        updates, state = optimizer.update(updates, state, params)
+        state = _move_fn(state, offload_dst)
+        return updates, state
+
+    def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
+        return copy_partition(
+            optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
+        )
+
+    return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py
index 5ed40098..49aa348e 100644
--- a/axlearn/common/optimizers_test.py
+++ b/axlearn/common/optimizers_test.py
@@ -40,6 +40,7 @@
     ema,
     l2_regularizer,
     lion_optimizer,
+    offload_optimizer,
     opt_param_values,
     param_ema,
     per_param_scale_by_path,
@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
         jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

     def _test_optimizer(self, optimizer):
-        params = OptParam(
-            value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
-            factorization_spec=None,
-            weight_decay_scale=1.0,
-        )
-        state = optimizer.init(params)
+        self._test_optimizer_helper(optimizer, True)
+        self._test_optimizer_helper(optimizer, False)
+
+    def _test_optimizer_helper(self, optimizer, offload):
+        if offload:
+            optimizer = offload_optimizer(optimizer)
+        params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
+
+        def create_opt_params(x):
+            return jax.tree.map(
+                lambda y: OptParam(
+                    value=y,
+                    factorization_spec=None,
+                    weight_decay_scale=1.0,
+                ),
+                x,
+            )
+
+        state = optimizer.init(create_opt_params(params))
         param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
         state_partition_spec = optimizer.partition(param_spec)
@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):
             jax.tree.map(check_partition_spec, state_partition_spec, state)

-        def compute_loss(x):
-            return -jax.nn.log_softmax(x)[1]
+        @jax.jit
+        def jit_fn(params, state):
+            def compute_loss(x):
+                return -jax.nn.log_softmax(x)[1]

-        loss, grads = jax.value_and_grad(compute_loss)(params.value)
-        updates, _ = optimizer.update(grads, state=state, params=params)
-        updated_params = optax.apply_updates(params.value, updates)
-        new_loss = compute_loss(updated_params)
+            params = create_opt_params(params)
+            loss, grads = jax.value_and_grad(compute_loss)(params.value)
+            updates, _ = optimizer.update(grads, state=state, params=params)
+            updated_params = optax.apply_updates(params.value, updates)
+            return loss, compute_loss(updated_params)
+
+        if offload:
+            self.assertIn(
+                "TransferToMemoryKind(memory_kind='pinned_host')",
+                str(jax.make_jaxpr(jit_fn)(params, state)),
+            )
+        loss, new_loss = jit_fn(params, state)
         self.assertLess(new_loss, loss)

     @parameterized.product(
@@ -788,14 +812,17 @@ def loss_fn(x):
             config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
             config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
         ),
+        offload=(True, False),
     )
-    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
+    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
         clip = skip_and_clip_by_global_norm(
             inner=_counter(),
             drop_norm=drop_norm,
             max_norm=max_norm,
             grad_norm_ema_decay=0.99,
         )
+        if offload:
+            clip = offload_optimizer(clip)
         params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
         state = clip.init(params)
         init_ema = state.grad_norm_ema
@@ -821,7 +848,11 @@ def loss_fn(x):
             else:
                 is_valid_step = drop_norm is None or g_norm < drop_norm

-            updates, state = clip.update(grads, state=state, params=params)
+            @jax.jit
+            def jit_fn(grads, state, params):
+                return clip.update(grads, state=state, params=params)
+
+            updates, state = jit_fn(grads, state, params)
             if is_valid_step:
                 if max_norm is None or g_norm < max_norm:
                     np.testing.assert_allclose(updates, grads, atol=1e-6)
diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py
index e952a060..4117137a 100644
--- a/axlearn/common/trainer.py
+++ b/axlearn/common/trainer.py
@@ -50,10 +50,10 @@
     HybridMeshShape,
     MeshShape,
     Nested,
-    NestedPartitionSpec,
     NestedTensor,
     PartitionSpec,
     Tensor,
+    TensorSpec,
     count_model_params,
     flatten_items,
     match_regex_rules,
@@ -62,9 +62,9 @@ class TrainerState(NamedTuple):
-    prng_key: Union[Tensor, NestedPartitionSpec]
-    model: Union[NestedTensor, NestedPartitionSpec]
-    learner: Union[NestedTensor, NestedPartitionSpec]
+    prng_key: Union[Tensor, TensorSpec, jax.sharding.NamedSharding]
+    model: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]
+    learner: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]


 # pylint: disable-next=too-many-instance-attributes
@@ -319,8 +319,8 @@ def __init__(
             model=self._model_param_specs,
             learner=self._learner_state_partition_specs,
         )
-        self._trainer_state_partition_specs = jax.tree.map(
-            lambda spec: spec.mesh_axes, self._trainer_state_specs
+        self._trainer_state_partition_specs: TrainerState = jax.tree.map(
+            lambda spec: spec.sharding, self._trainer_state_specs
         )
         # Create evalers, which depend on model_param_partition_specs.
         self._evalers = {}
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index a25a798a..84c2f04d 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -27,7 +27,7 @@
 import types
 from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Any, Callable, NamedTuple, Optional, Protocol, TypeVar, Union
+from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, TypeVar, Union

 import jax
 import numpy as np
@@ -91,6 +91,13 @@ def __len__(self):
         return len(self.ici_mesh_shape)


+# "device" = Accelerator memory, e.g. HBM.
+# "pinned_host" = Page-locked memory on the CPU, which can be addressed directly by accelerators
+# through direct memory access (DMA). For TPU, the "pinned_host" memory layout follows the TPU
+# device tile layout and usually cannot be zero-copy converted to a CPU tensor.
+MemoryKind = Literal["device", "pinned_host"]
+
+
 @dataclasses.dataclass
 class TensorSpec:
     """Specification of a Tensor.
@@ -101,11 +108,12 @@ class TensorSpec:
     shape: Sequence[int]
     dtype: Optional[jnp.dtype] = None
     mesh_axes: Optional[PartitionSpec] = None
+    memory_kind: Optional[MemoryKind] = None

     @property
     def sharding(self) -> jax.sharding.Sharding:
         mesh = thread_resources.env.physical_mesh
-        return jax.sharding.NamedSharding(mesh, self.mesh_axes)
+        return jax.sharding.NamedSharding(mesh, self.mesh_axes, memory_kind=self.memory_kind)


 NestedTensorSpec = Optional[Union[TensorSpec, dict[str, Any]]]
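
Reviewer note (not part of the patch): a minimal sketch of how the new API is intended to compose, following the `offload_optimizer` docstring above. Only the wrapping order and the `pattern`/`offload_src`/`offload_dst`/`inner`/`max_norm`/`drop_norm`/`grad_norm_ema_decay` arguments are taken from this diff; the `adamw_optimizer` hyperparameter values are illustrative placeholders.

```python
# Illustrative sketch only; hyperparameter values are placeholders, not part of this diff.
from axlearn.common.optimizers import (
    adamw_optimizer,
    offload_optimizer,
    skip_and_clip_by_global_norm,
)

# Wrap the *entire* skip/clip transformation, not just its inner optimizer: the jnp.where
# selection inside skip_and_clip_by_global_norm cannot read tensors on pinned_host memory.
offloaded_opt = offload_optimizer(
    skip_and_clip_by_global_norm(
        inner=adamw_optimizer(learning_rate=1e-3, b1=0.9, b2=0.95, eps=1e-8),
        drop_norm=100.0,
        max_norm=1.0,
        grad_norm_ema_decay=0.99,
    ),
    # Offload every state by default; use e.g. ".*/mu/.*" to offload only `mu`.
    pattern=".*",
    offload_src="device",
    offload_dst="pinned_host",
)

# offloaded_opt.partition(...) places matched state specs on pinned_host, and
# offloaded_opt.update(...) must run inside jit so the TransferToMemoryKind copies can be
# overlapped with the state update by the latency hiding scheduler.
```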
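The `utils.py` change is the low-level hook the rest of the patch relies on: a `TensorSpec` with `memory_kind="pinned_host"` resolves to a `NamedSharding` pinned to host memory. Below is a small standalone sketch of that JAX-level behavior; the mesh shape, axis name, and array shape are made up, and it assumes a backend (e.g. TPU or GPU with a recent JAX) that exposes a `pinned_host` memory space.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A toy 1-D mesh over whatever devices are available; the axis name is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# This mirrors what TensorSpec.sharding builds when memory_kind="pinned_host" is set:
# the same mesh_axes, but resident in pinned host memory instead of HBM.
pinned = NamedSharding(mesh, PartitionSpec("model"), memory_kind="pinned_host")

# Shape chosen so dim 0 divides evenly across the "model" axis.
x = jax.device_put(np.ones((2 * jax.device_count(), 4), dtype=np.float32), pinned)
print(x.sharding.memory_kind)  # -> "pinned_host" on backends that support host memory spaces
```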