Commit

enable special remat for neuron
apoorvtintin committed Jan 8, 2025
1 parent 3ae8f9f commit 382e93e
Showing 61 changed files with 263 additions and 72 deletions.
13 changes: 10 additions & 3 deletions axlearn/common/attention.py
@@ -3956,15 +3956,22 @@ def policy(prim, *_, **params):
return policy


-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*[kqv]_proj"
+    O_PROJ = r".*o_proj"
+    CONTEXT = r".*context"
+    LINEAR1_X = r".*linear1_[01]"
+    LINEAR2_X = r".*linear2_[01]"
+    SELF_ATTENTION = re.compile("|".join([QKV_PROJ, O_PROJ, CONTEXT])).pattern
+    FEED_FORWARD = re.compile("|".join([LINEAR1_X, LINEAR2_X])).pattern


def build_remat_spec(
    stack_cfg: Union[
        BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
    ],
-    save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
    offload_pattern: _SavePattern = None,
    offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
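The enum members join with "|" into the same save patterns that the removed module-level constants expressed, while also exposing each sub-pattern on its own; the Neuron mesh rule further down saves QKV_PROJ and LINEAR1_X without O_PROJ or CONTEXT. A minimal sketch of the relationship, using example remat names that are assumptions for illustration rather than values from axlearn:

import re

OLD_SELF_ATTENTION = r".*([qkvo]_proj|context)"  # removed constant
OLD_FEED_FORWARD = r".*linear[12]_.*"  # removed constant
NEW_SELF_ATTENTION = "|".join([r".*[kqv]_proj", r".*o_proj", r".*context"])
NEW_FEED_FORWARD = "|".join([r".*linear1_[01]", r".*linear2_[01]"])

# The combined self-attention patterns accept and reject the same names.
for name in [
    "decoder/layer0/self_attention/q_proj",
    "decoder/layer0/self_attention/o_proj",
    "decoder/layer0/self_attention/context",
    "decoder/layer0/feed_forward/linear1_0",
]:
    assert bool(re.fullmatch(NEW_SELF_ATTENTION, name)) == bool(
        re.fullmatch(OLD_SELF_ATTENTION, name)
    )

# The feed-forward pattern became stricter: the old regex matched any suffix
# after linear1_/linear2_, the new one only the _0/_1 checkpoint names.
assert re.fullmatch(OLD_FEED_FORWARD, "feed_forward/linear1_tag")
assert not re.fullmatch(NEW_FEED_FORWARD, "feed_forward/linear1_tag")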
68 changes: 66 additions & 2 deletions axlearn/common/attention_test.py
@@ -41,7 +41,6 @@

from axlearn.common import attention, attention_bias, test_utils, utils
from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
    BaseStackedTransformerLayer,
    BaseTransformerLayer,
    BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
    PipelinedTransformerLayer,
    QKVLinear,
    QLinear,
+    RematRegexSavePatterns,
    RepeatedTransformerLayer,
    RoFormerQKVLinear,
    StackedTransformerLayer,
@@ -3420,7 +3420,7 @@ def f(x, layer_params):
jax.remat(
    f,
    policy=_save_and_offload_only_these_names_regex(
-        names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+        names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
        names_which_can_be_offloaded=None,
        offload_src="device",
        offload_dst="pinned_host",
@@ -3875,6 +3875,70 @@ def f(x, layer_params):
5,
)

+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(_save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Saving q/k/v_proj and linear1_0 eliminates recomputation of 4 dot_generals.
+        # This assumes FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            4,
+        )
+

class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
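The counting in the new test relies on a simple invariant: every activation the remat policy saves is one the backward pass no longer has to recompute, so its producing matmul disappears from the backward jaxpr; q_proj, k_proj, v_proj, and linear1_0 account for the 4 dot_generals. A standalone sketch of the same counting trick on a toy two-matmul layer (the layer, shapes, and the "linear1_0" tag here are assumptions for illustration, not axlearn code):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w1, w2):
    # Tag the first matmul's output so a remat policy can refer to it by name.
    h = checkpoint_name(jnp.dot(x, w1), "linear1_0")
    return jnp.dot(jax.nn.relu(h), w2).sum()

save_h = jax.checkpoint_policies.save_only_these_names("linear1_0")
full = jax.remat(layer)  # recompute everything in the backward pass
saved = jax.remat(layer, policy=save_h)  # keep the tagged output instead

x, w1, w2 = jnp.ones((2, 4)), jnp.ones((4, 8)), jnp.ones((8, 3))
n_full = str(jax.make_jaxpr(jax.grad(full, argnums=(0, 1, 2)))(x, w1, w2)).count(" dot_general")
n_saved = str(jax.make_jaxpr(jax.grad(saved, argnums=(0, 1, 2)))(x, w1, w2)).count(" dot_general")
print(n_full - n_saved)  # expected: 1, the single matmul that saving "linear1_0" spares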

@@ -485,7 +485,7 @@ model.encoder.context.context.num_layers: 17
model.encoder.context.context.remat_spec['prevent_cse']: False
model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
-model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'

@@ -188,7 +188,7 @@ model.encoder.context.context.num_layers: 1
model.encoder.context.context.remat_spec['prevent_cse']: False
model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
-model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
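The mesh_rules addition above maps the instance-type regex 'neuron-(trn2|trn2n).48xlarge-64' to a chained config modifier: a 6-axis mesh of (1, 1, 1, -1, 1, 4) and a remat policy that saves only the q/k/v projections and the first feed-forward matmuls, with offloading disabled and prevent_cse set to True. A hedged reconstruction of how such a rule could be built in Python (the klass and field names come from the dump above; the exact construction and the RematSpec import path are assumptions):

from axlearn.common.attention import (
    RematRegexSavePatterns,
    _save_and_offload_only_these_names_regex,
)
from axlearn.common.config import config_for_function
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    MeshShapeModifier,
    RematSpecModifier,
)
from axlearn.common.utils import RematSpec  # assumed import path

neuron_trn2_rule = (
    "neuron-(trn2|trn2n).48xlarge-64",
    ChainConfigModifier.default_config().set(
        config_modifiers=[
            # config_modifiers[0]: the mesh shape from the golden config above.
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
            # config_modifiers[1]: save only q/k/v projections and the first
            # feed-forward matmuls; no offloading.
            RematSpecModifier.default_config().set(
                remat_policies={
                    "model.decoder.transformer.layer": RematSpec(
                        prevent_cse=True,
                        policy=config_for_function(
                            _save_and_offload_only_these_names_regex
                        ).set(
                            names_which_can_be_saved="|".join(
                                [
                                    RematRegexSavePatterns.QKV_PROJ.value,
                                    RematRegexSavePatterns.LINEAR1_X.value,
                                ]
                            ),
                            names_which_can_be_offloaded=None,
                            offload_src=None,
                            offload_dst=None,
                        ),
                    ),
                }
            ),
        ],
    ),
)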

@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True

@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True