Optimizer offloading through weight-only offload #867

Open. Wants to merge 4 commits into base: main. Changes shown from 3 commits.
7 changes: 4 additions & 3 deletions axlearn/common/factorized_rms_test.py
@@ -12,11 +12,12 @@
from axlearn.common import factorized_rms
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.optimizer_base import (
NestedOptStateSpec,
Nested,
OptParam,
OptStateSpec,
PartitionedGradientTransformation,
)
from axlearn.common.optimizers import OptStateSpec, with_partition_fn
from axlearn.common.optimizers import with_partition_fn
from axlearn.common.test_utils import TestCase
from axlearn.common.utils import PartitionSpec, flatten_items

@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

# The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
# factorization spec.
exp_partition: NestedOptStateSpec = exp.partition(param_specs)
exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
# Used for `count`.
count_spec = OptStateSpec(
dtype=jnp.int32,
8 changes: 3 additions & 5 deletions axlearn/common/optimizer_base.py
@@ -16,14 +16,13 @@
- weight_decay_scale: control the weight decay rate.
"""
import dataclasses
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Union

import optax
import typing_extensions

from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
from axlearn.common.utils import Tensor, TensorSpec
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.utils import Nested, Tensor, TensorSpec


@dataclasses.dataclass
@@ -66,8 +65,7 @@ def __call__(

# Specification of an optimizer state array.
OptStateSpec = TensorSpec
NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


class PartitionedGradientTransformation(NamedTuple):
134 changes: 123 additions & 11 deletions axlearn/common/optimizers.py
@@ -36,10 +36,11 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
@@ -51,8 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
MemoryKind,
Nested,
NestedPartitionSpec,
NestedTensor,
NestedTree,
Tensor,
@@ -139,19 +140,40 @@ def update_fn(
return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def copy_partition(
param_specs: Nested[ParameterSpec],
*,
pattern: Union[None, str, re.Pattern] = None,
memory_kind: Optional[MemoryKind] = None,
) -> Nested[OptStateSpec]:
Contributor:
Instead of coupling creation of OptStateSpec and setting of memory_kind, how about having a separate function for setting memory kind?

def set_memory_kind(opt_state_spec: Nested[OptStateSpec], *, pattern, memory_kind):

This allows set_memory_kind to be called multiple times, maybe for different memory kind. WDYT?

Member Author:
I do not see how set_memory_kind will be different from copy_partition. Signature and implementation will be the same.

Contributor:
Imagine in the future we have many types of memory kinds, e.g., "remote_host". Then we can do:

opt_state_specs = copy_partition(...)
opt_state_specs = set_memory_kind(..., "pinned_host")
opt_state_specs = set_memory_kind(..., "remote_host")

Member Author (@hanzhi713, Dec 5, 2024):
Wouldn't it be the same as

opt_state_specs = copy_partition(...)
opt_state_specs = copy_partition(..., "pinned_host")
opt_state_specs = copy_partition(..., "remote_host")

Do you mean that using a separate function is slightly better for readability?

Contributor:
The difference is that copy_partition also performs the type conversion from Nested[ParameterSpec] to Nested[OptStateSpec].

Member Author:
I see. I can change the type of param_specs in copy_partition to Nested[OptStateSpec], since ParameterSpec is a subclass of OptStateSpec and copy_partition doesn't use any new fields from ParameterSpec. Does this sound good?

Contributor:
Thanks. SG.

Member Author:
Done.
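For reference, a minimal sketch of the `set_memory_kind` decomposition discussed above (hypothetical; it was not adopted, and the diff instead extends `copy_partition` as shown below). It assumes `tree_paths`, `MemoryKind`, and `Nested` are available from `axlearn.common.utils` and that `OptStateSpec` carries a `memory_kind` field, as in this PR:

```python
import re
from typing import Optional, Union

import jax

from axlearn.common.optimizer_base import OptStateSpec
from axlearn.common.utils import MemoryKind, Nested, tree_paths


def set_memory_kind(
    opt_state_specs: Nested[OptStateSpec],
    *,
    pattern: Union[str, re.Pattern],
    memory_kind: Optional[MemoryKind],
) -> Nested[OptStateSpec]:
    # For each spec whose full tree path matches `pattern`, replace its memory kind;
    # leave all other specs unchanged.
    return jax.tree.map(
        lambda path, spec: OptStateSpec(
            dtype=spec.dtype,
            shape=spec.shape,
            mesh_axes=spec.mesh_axes,
            memory_kind=memory_kind if re.fullmatch(pattern, path) else spec.memory_kind,
        ),
        tree_paths(opt_state_specs),
        opt_state_specs,
    )
```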

"""Creates OptStateSpec from ParameterSpec with possibly a different memory kind.

Args:
param_specs: Nested[ParameterSpec] to copy from.
pattern: Regex to match the full path of each spec. Matched specs will have their memory
kind replaced with `memory_kind`.
memory_kind: New memory kind. Defaults to None.
Returns:
A Nested[OptStateSpec] with possibly a different memory kind.
"""
return jax.tree.map(
lambda param_spec: OptStateSpec(
dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
lambda path, param_spec: OptStateSpec(
dtype=param_spec.dtype,
shape=param_spec.shape,
mesh_axes=param_spec.mesh_axes,
memory_kind=memory_kind
if pattern and re.fullmatch(pattern, path)
else param_spec.memory_kind,
),
tree_paths(param_specs),
param_specs,
)
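A usage sketch of the extended `copy_partition` (the parameter tree, shapes, and pattern are illustrative; `ParameterSpec` is constructed the same way as in the tests below, and paths follow `tree_paths`'s "/"-separated convention):

```python
from jax import numpy as jnp

from axlearn.common.base_layer import ParameterSpec
from axlearn.common.optimizers import copy_partition
from axlearn.common.utils import PartitionSpec

# A toy parameter tree; the full paths are "layer/weight" and "layer/bias".
param_specs = {
    "layer": {
        "weight": ParameterSpec(
            dtype=jnp.float32, shape=[8, 8], mesh_axes=PartitionSpec("model"), factorization=None
        ),
        "bias": ParameterSpec(
            dtype=jnp.float32, shape=[8], mesh_axes=PartitionSpec(), factorization=None
        ),
    },
}

# Plain conversion: Nested[ParameterSpec] -> Nested[OptStateSpec], memory kinds unchanged.
opt_specs = copy_partition(param_specs)

# Same conversion, but specs whose full path matches the pattern are placed on pinned host memory.
offloaded_specs = copy_partition(param_specs, pattern=".*weight", memory_kind="pinned_host")
```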


def trace_partition(
base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.TraceState(trace=copy_partition(param_specs))

return with_partition_fn(base, partition_fn)
@@ -160,7 +182,7 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
state: optax.ScaleByAdamState = base.init({})

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.ScaleByAdamState(
count=OptStateSpec(
dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +972,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
)
return updates, new_state

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
# Store momentum in accumulator_dtype if it is set and p is not scalar.
if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1434,7 @@ def _is_valid_step(
drop_stats=new_drop_stats,
)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
if use_adaptive_drop_norm:
one = jnp.ones([], jnp.float32)
dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1593,7 @@ def update_fn(updates, state, params):
)
return updates, ParamEmaState(count=count_inc, ema=new_ema)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return ParamEmaState(
count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
ema=copy_partition(param_specs),
@@ -1617,7 +1639,7 @@ def update_fn(updates, state, params=None):
updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
return updates, ScaleByLionState(count=count_inc, mu=mu)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
mu_specs = param_specs
if mu_dtype is not None:
mu_specs = jax.tree.map(
@@ -1993,3 +2015,93 @@ def _update2(u: Tensor, param: OptParam):
partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
)
return named_chain(**tx)


def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
offload_src: MemoryKind = "device",
offload_dst: MemoryKind = "pinned_host",
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.

Args:
optimizer: The optimizer to offload.
pattern: Regex pattern used to match the full path of optimizer states. Fully matched states
will be offloaded. Defaults to a regex that matches all states.
offload_src: Offload-from memory kind. Defaults to "device".
offload_dst: Offload-to memory kind. Defaults to "pinned_host".

Returns:
An optimizer whose state lives on `offload_dst` and that performs the same computation as `optimizer`.

Raises:
ValueError: If the `update` function of the returned optimizer is called outside of a jit
context.

This function returns a new `PartitionedGradientTransformation` that
1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
during state initialization in the trainer.
2. Copies the matched states to `offload_src` before `optimizer.update` is called.
3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.

The regex pattern is matched against the full path of each optimizer state. An example full
path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
pattern should not depend on the model structure, you can use ".*/mu/.*" to offload all `mu` states.

The `update` function of the returned `PartitionedGradientTransformation` must be called
within a jit function.

Example usage:
```python
your_opt = adamw_optimizer(...)
offloaded_opt = offload_optimizer(your_opt)
```

When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the entire
`skip_and_clip_by_global_norm` transformation with `offload_optimizer`. Do not wrap only the
inner optimizer of `skip_and_clip_by_global_norm`, or you will get errors. Correct example:
```
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
```
The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
optimizer state and the updated optimizer state using `jnp.where`, which doesn't support tensors
in the `pinned_host` memory space.
"""
optimizer = maybe_instantiate(optimizer)
if offload_src is None or offload_dst is None:
raise ValueError(
"offload_src and offload_dst cannot be None when using optimizer offloading."
)

logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)

def init_fn(params: NestedOptParam):
return optimizer.init(params)

def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
# TransferToMemoryKind lets us change the memory kind of tensors without specifying the full
# sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
# it's specified in the API signature. Reference:
# https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
return jax.tree.map(
lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
if re.fullmatch(pattern, path)
else tensor,
tree_paths(state),
state,
)
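A standalone sketch of the transfer mechanism used in `_move_fn` above. Note that `TransferToMemoryKind` lives under `jax._src` and is not a documented public API, so this rests on the same assumption as the comment above; the transfer is placed inside a jitted function, matching the jit requirement stated in the docstring:

```python
import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind


@jax.jit
def round_trip(x):
    # Move the tensor to pinned host memory without spelling out a full NamedSharding.
    x_host = jax.device_put(x, TransferToMemoryKind(memory_kind="pinned_host"))
    # Bring it back to device memory for computation.
    return jax.device_put(x_host, TransferToMemoryKind(memory_kind="device"))


y = round_trip(jnp.ones([4, 4]))
```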

def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
state = _move_fn(state, offload_src)
updates, state = optimizer.update(updates, state, params)
state = _move_fn(state, offload_dst)
return updates, state

def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return copy_partition(
optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
)

return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
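Putting it together, a sketch of the intended composition. The `adamw_optimizer` arguments are illustrative placeholders (substitute your existing optimizer config), and per the docstring's caveat the entire `skip_and_clip_by_global_norm` transformation is wrapped, not just its inner optimizer:

```python
from axlearn.common.optimizers import (
    adamw_optimizer,
    offload_optimizer,
    skip_and_clip_by_global_norm,
)

base_opt = adamw_optimizer(learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8)

# Offload all optimizer states of the wrapped transformation to pinned host memory.
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=base_opt))

# Alternatively, offload only Adam's first-moment (`mu`) states and keep the rest on device.
mu_only_opt = offload_optimizer(base_opt, pattern=".*/mu/.*")
```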
59 changes: 45 additions & 14 deletions axlearn/common/optimizers_test.py
@@ -40,6 +40,7 @@
ema,
l2_regularizer,
lion_optimizer,
offload_optimizer,
opt_param_values,
param_ema,
per_param_scale_by_path,
@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

def _test_optimizer(self, optimizer):
params = OptParam(
value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
factorization_spec=None,
weight_decay_scale=1.0,
)
state = optimizer.init(params)
self._test_optimizer_helper(optimizer, True)
self._test_optimizer_helper(optimizer, False)

def _test_optimizer_helper(self, optimizer, offload):
if offload:
optimizer = offload_optimizer(optimizer)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)

def create_opt_params(x):
return jax.tree.map(
lambda y: OptParam(
value=y,
factorization_spec=None,
weight_decay_scale=1.0,
),
x,
)

state = optimizer.init(create_opt_params(params))

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
state_partition_spec = optimizer.partition(param_spec)
@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

jax.tree.map(check_partition_spec, state_partition_spec, state)

def compute_loss(x):
return -jax.nn.log_softmax(x)[1]
@jax.jit
def jit_fn(params, state):
def compute_loss(x):
return -jax.nn.log_softmax(x)[1]

loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
new_loss = compute_loss(updated_params)
params = create_opt_params(params)
loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
return loss, compute_loss(updated_params)

if offload:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
str(jax.make_jaxpr(jit_fn)(params, state)),
)
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

@parameterized.product(
@@ -788,14 +812,17 @@ def loss_fn(x):
config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
),
offload=(True, False),
)
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
clip = skip_and_clip_by_global_norm(
inner=_counter(),
drop_norm=drop_norm,
max_norm=max_norm,
grad_norm_ema_decay=0.99,
)
if offload:
clip = offload_optimizer(clip)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
state = clip.init(params)
init_ema = state.grad_norm_ema
@@ -821,7 +848,11 @@ def loss_fn(x):
else:
is_valid_step = drop_norm is None or g_norm < drop_norm

updates, state = clip.update(grads, state=state, params=params)
@jax.jit
def jit_fn(grads, state, params):
return clip.update(grads, state=state, params=params)

updates, state = jit_fn(grads, state, params)
if is_valid_step:
if max_norm is None or g_norm < max_norm:
np.testing.assert_allclose(updates, grads, atol=1e-6)
12 changes: 6 additions & 6 deletions axlearn/common/trainer.py
@@ -50,10 +50,10 @@
HybridMeshShape,
MeshShape,
Nested,
NestedPartitionSpec,
NestedTensor,
PartitionSpec,
Tensor,
TensorSpec,
count_model_params,
flatten_items,
match_regex_rules,
@@ -62,9 +62,9 @@


class TrainerState(NamedTuple):
prng_key: Union[Tensor, NestedPartitionSpec]
model: Union[NestedTensor, NestedPartitionSpec]
learner: Union[NestedTensor, NestedPartitionSpec]
prng_key: Union[Tensor, TensorSpec, jax.sharding.NamedSharding]
model: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]
learner: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]


# pylint: disable-next=too-many-instance-attributes
@@ -309,8 +309,8 @@ def __init__(
model=self._model_param_specs,
learner=self._learner_state_partition_specs,
)
self._trainer_state_partition_specs = jax.tree.map(
lambda spec: spec.mesh_axes, self._trainer_state_specs
self._trainer_state_partition_specs: TrainerState = jax.tree.map(
lambda spec: spec.sharding, self._trainer_state_specs
)
# Create evalers, which depend on model_param_partition_specs.
self._evalers = {}