-
Notifications
You must be signed in to change notification settings - Fork 280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimizer offloading through weight-only offload #867
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
axlearn/common/optimizers.py
Outdated
Only wrap the optimizer that you actually want to offload with this function to avoid | ||
unneseccary overhead. This is usually the optimizer that occupies the most HBM. For example, | ||
when you have chained optimizers: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where does the overhead come from? Is it from the states of clip_by_global_norm
being offloaded? If so, could we use regular expressions to specify which states to offload?
axlearn/common/optimizers.py
Outdated
state = jax.device_put(state, TransferToMemoryKind(offload_src)) | ||
updates, state = optimizer.update(updates, state, params) | ||
state = jax.device_put(state, TransferToMemoryKind(offload_dst), donate=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need explicit device_put
calls here? Is it enough to specify the partition spec with the right memory_kind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haven't tested with_sharding_constraint, if you are refering to that. However, specifying full sharding requires us to store the sharding when the partition fn is called, which is not preferred by John (see internal comments)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether we need with_sharding_constraint
, since jit/pjit should already apply shardings returned by partition_fn
to the states. What happens without these device_put
calls?
If necessary, we can always invoke partition_fn
here to compute the sharding on the fly (instead of storing them) and apply with_sharding_constraint
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A question about device_put...
axlearn/common/optimizers.py
Outdated
state = jax.device_put(state, TransferToMemoryKind(offload_src)) | ||
updates, state = optimizer.update(updates, state, params) | ||
state = jax.device_put(state, TransferToMemoryKind(offload_dst), donate=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether we need with_sharding_constraint
, since jit/pjit should already apply shardings returned by partition_fn
to the states. What happens without these device_put
calls?
If necessary, we can always invoke partition_fn
here to compute the sharding on the fly (instead of storing them) and apply with_sharding_constraint
.
Before the optimizer can be invoked, the offloaded optimizer states need to be transferred to device memory space. If we remove these device_put calls, we will get errors like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification on device_put
calls. Could you add a comment on why it's necessary? Also two suggestions...
axlearn/common/optimizers.py
Outdated
def copy_partition( | ||
param_specs: Nested[ParameterSpec], | ||
*, | ||
pattern: Union[None, str, re.Pattern] = None, | ||
memory_kind: Optional[MemoryKind] = None, | ||
) -> Nested[OptStateSpec]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of coupling creation of OptStateSpec
and setting of memory_kind
, how about having a separate function for setting memory kind?
def set_memory_kind(opt_state_spec: Nested[OptStateSpec], *, pattern, memory_kind):
This allows set_memory_kind
to be called multiple times, maybe for different memory kind. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not see how set_memory_kind
will be different from copy_partition
. Signature and implementation will be the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Imagine in the future we have many types of memory kinds, e.g., "remote_host". Then we can do:
opt_state_specs = copy_partition(...)
opt_state_specs = set_memory_kind(..., "pinned_host")
opt_state_specs = set_memory_kind(..., "remote_host")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be the same as
opt_state_specs = copy_partition(...)
opt_state_specs = copy_partition(..., "pinned_host")
opt_state_specs = copy_partition(..., "remote_host")
Do you mean that using a separate function is slightly better for readability?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The difference is that copy_partition
also performs the type conversion from Nested[ParameterSpec]
to Nested[OptStateSpec]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. I can change the type of param_specs in copy_partition
to Nested[OptStateSpec]
since ParameterSpec
is a subclass of OptStateSpec
and copy_partition
doesn't use any new fields from ParameterSpec
. Does this sound good?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. SG.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
pattern: Regex to match the full path of each spec. Matched specs will have their memory | ||
kind replaced with `memory_kind`. | ||
memory_kind: New memory kind. Default to None. | ||
Returns: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns: | |
Returns: |
This PR requires jax >= 0.4.34, != 0.4.35, >=0.4.36: it works on jax 0.4.34, but is broken on jax 0.4.35 due to libtpu bug. It worked on nightly jax 0.4.36 as of 10/30.
This PR represents effort to enable optimizer offloading. The approach we use in this PR is weight-only offloading, which is based on similar building blocks as activation offloading (aka remat offload). When offloading is enabled, optimizer states are stored on CPU pinned memory. Before apply optimizer to calculate updates, optimizer states are moved from CPU memory to HBM via
jax.device_put
. The new optimizer states are moved back from HBM to CPU.An alternative approach to this PR is host computation. Host computation means that optimizer transformations are computed on CPU. Before the start of the computation, gradients and weights are transferred to CPU, and after the computation, their new values are transferred back to HBM. This method has lower HBM footprint, but it's much 2x ~ 3x slower due to slow CPU computation. Also, it's very buggy.
TLDR: to be merged after upgrading jax to 0.4.36.