enable special remat for neuron
apoorvtintin committed Jan 13, 2025
1 parent 3f36108 commit ed58dd2
Showing 10 changed files with 207 additions and 17 deletions.
13 changes: 10 additions & 3 deletions axlearn/common/attention.py
@@ -4001,15 +4001,22 @@ def _save_and_offload_only_these_names_regex(
 )
 
 
-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names.
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*[kqv]_proj"
+    O_PROJ = r".*o_proj"
+    CONTEXT = r".*context"
+    LINEAR1_X = r".*linear1_[01]"
+    LINEAR2_X = r".*linear2_[01]"
+    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])
 
 
 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
     offload_pattern: SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
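A note on the new enum, and a sketch of how its patterns behave. Inside an Enum class body, LINEAR1_X and LINEAR2_X are still plain strings when FEED_FORWARD is assigned, so the `"|".join([...])` runs at class-creation time and FEED_FORWARD's value is a single alternation pattern. The snippet below is illustrative only: the tensor names are hypothetical, and fullmatch-style matching is an assumption about how the remat policy applies these regexes.

```python
import re

# Values mirror RematRegexSavePatterns above.
QKV_PROJ = r".*[kqv]_proj"
LINEAR1_X = r".*linear1_[01]"
LINEAR2_X = r".*linear2_[01]"
FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])  # ".*linear1_[01]|.*linear2_[01]"

# Hypothetical remat checkpoint names of the kind these patterns target.
names = [
    "transformer/layer/self_attention/q_proj",
    "transformer/layer/self_attention/o_proj",
    "transformer/layer/feed_forward/linear1_0",
    "transformer/layer/feed_forward/linear2_1",
]
print([n for n in names if re.fullmatch(QKV_PROJ, n)])      # only the q_proj name
print([n for n in names if re.fullmatch(FEED_FORWARD, n)])  # the two linear* names
```

Splitting SELF_ATTENTION and FEED_FORWARD into finer-grained members is what lets platform-specific rules (such as the Neuron rule in fuji.py below) re-join exactly the subset they want instead of maintaining ad-hoc regex strings.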
68 changes: 66 additions & 2 deletions axlearn/common/attention_test.py
@@ -41,7 +41,6 @@
 
 from axlearn.common import attention, attention_bias, test_utils, utils
 from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     BaseTransformerLayer,
     BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
     PipelinedTransformerLayer,
     QKVLinear,
     QLinear,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
     StackedTransformerLayer,
@@ -3446,7 +3446,7 @@ def f(x, layer_params):
             jax.remat(
                 f,
                 policy=_save_and_offload_only_these_names_regex(
-                    names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+                    names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
                     names_which_can_be_offloaded=None,
                     offload_src="device",
                     offload_dst="pinned_host",
@@ -3901,6 +3901,70 @@ def f(x, layer_params):
             5,
         )
 
+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(_save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Eliminated the remat of qkv_proj and linear1_0 = 4 dots. This assumes
+        # FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            4,
+        )
+
 
 class TestStackModel(BaseLayer):
     """A dummy transformer stack."""
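The new test relies on a counting trick: linearize the layer under two remat configurations and diff the number of dot_general ops in the backward computation, since saving an activation removes its recompute dot from the backward pass. A self-contained sketch of the same idea on a toy MLP (not axlearn code; jax.checkpoint_policies.checkpoint_dots stands in for the custom name-regex policy, and exact counts depend on the model):

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # Two forward matmuls; under full remat both are recomputed
    # in the backward pass.
    return jnp.dot(jax.nn.relu(jnp.dot(x, w1)), w2).sum()

x, w1, w2 = jnp.ones((4, 8)), jnp.ones((8, 16)), jnp.ones((16, 4))

# Full remat: save nothing, recompute everything.
_, bwd_full = jax.linearize(jax.remat(mlp), x, w1, w2)
# Save matmul outputs instead of recomputing them.
_, bwd_dots = jax.linearize(
    jax.remat(mlp, policy=jax.checkpoint_policies.checkpoint_dots), x, w1, w2
)

# Same str(...) inspection the test uses: fewer backward dot_generals
# when the forward dots are saved.
assert str(bwd_full).count(" dot_general") > str(bwd_dots).count(" dot_general")
```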
16 changes: 16 additions & 0 deletions (golden trainer-config file; filename not shown in this capture)
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
The same 16-line mesh_rules[4] addition is applied verbatim in five more golden trainer-config files (16 additions & 0 deletions each).
11 changes: 2 additions & 9 deletions axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
     BaseQKVLinear,
     MultiheadAttention,
     RepeatedTransformerLayer,
+    StackedTransformerLayer,
     TransformerLayer,
     build_remat_spec,
     set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.
-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.
-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
-
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
@@ -277,7 +270,7 @@ def model_config(
     layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
+    if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
         update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
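The guard change in model_config is what makes the new rule reachable: the old identity check matched only RepeatedTransformerLayer itself, while issubclass with a tuple also accepts StackedTransformerLayer and subclasses of either. Correspondingly, update_model_remat_config no longer raises NotImplementedError for other stacks; the caller now does the gating. A minimal sketch of the difference using stand-in classes (not the axlearn types):

```python
class RepeatedTransformerLayer: ...
class StackedTransformerLayer: ...
class MyStack(StackedTransformerLayer): ...  # hypothetical subclass

for klass in (RepeatedTransformerLayer, StackedTransformerLayer, MyStack):
    old_rule = klass is RepeatedTransformerLayer
    new_rule = issubclass(klass, (RepeatedTransformerLayer, StackedTransformerLayer))
    print(f"{klass.__name__}: old={old_rule}, new={new_rule}")
# RepeatedTransformerLayer: old=True, new=True
# StackedTransformerLayer: old=False, new=True
# MyStack: old=False, new=True
```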
36 changes: 33 additions & 3 deletions axlearn/experiments/text/gpt/fuji.py
@@ -19,14 +19,15 @@
 
 from axlearn.common import causal_lm, config
 from axlearn.common.attention import (
-    SELF_ATTENTION_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    _save_and_offload_only_these_names_regex,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -86,7 +87,6 @@ class Version(enum.Enum):
     Version.V3: 5e5,
 }
 
-
 # Mapping from Fuji versions to total number of tokens used in training.
 TOTAL_TOKENS = {
     Version.V1: {
@@ -147,7 +147,7 @@ def get_trainer_kwargs(
                 extended_checkpoint_policies.save_and_offload_only_these_names_regex
             ).set(
                 names_which_can_be_saved=None,
-                names_which_can_be_offloaded=SELF_ATTENTION_SAVE_PATTERN,
+                names_which_can_be_offloaded=RematRegexSavePatterns.SELF_ATTENTION.value,
                 offload_src="device",
                 offload_dst="pinned_host",
             )
@@ -492,6 +492,36 @@ def get_trainer_kwargs(
                 (
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            _save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved="|".join(
+                                                [
+                                                    RematRegexSavePatterns.QKV_PROJ.value,
+                                                    RematRegexSavePatterns.LINEAR1_X.value,
+                                                ]
+                                            ),
+                                            names_which_can_be_offloaded=None,
+                                            offload_src=None,
+                                            offload_dst=None,
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
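Putting the pieces together: the Neuron mesh rule saves only the q/k/v projections and the first feed-forward matmul, leaving o_proj, context, and linear2 to be rematerialized; per the test above, that removes four dot_general recomputations per layer from the backward pass in the test's configuration. A quick sanity check of the joined pattern with illustrative names (only the pattern string itself comes from the diff; fullmatch semantics are an assumption):

```python
import re

pattern = "|".join([r".*[kqv]_proj", r".*linear1_[01]"])
# Exactly the string recorded in the golden configs above.
assert pattern == ".*[kqv]_proj|.*linear1_[01]"

saved = ["decoder/layer/q_proj", "decoder/layer/linear1_0"]
recomputed = ["decoder/layer/o_proj", "decoder/layer/context", "decoder/layer/linear2_0"]
assert all(re.fullmatch(pattern, n) for n in saved)
assert not any(re.fullmatch(pattern, n) for n in recomputed)
```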
