Skip to content

Commit

Permalink
Introduce the scale enum flag in Embedding layer for LLM embedding.
Browse files Browse the repository at this point in the history
The activation component should roughly have a magnitude of 1. Since the embedding tensor is
initialized with a scale of `1/sqrt(dim)`, the activation is multiplied by `sqrt(dim)` to
maintain the desired scale. e.g. Gemma [1]
[1] https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80

In addition, unsloth [2] discovered that `sqrt(dim)` needs to be computed in float32.
[2] Sec 3 in https://unsloth.ai/blog/gemma-bugs

TODO(axlearn-team): Use UNIT scale enum for AFM+. This will require re-sweeping
hyperparameters (e.g., learning rate).
  • Loading branch information
ds-hwang committed Jan 8, 2025
1 parent c40b39a commit 8167248
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
30 changes: 30 additions & 0 deletions axlearn/common/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ def test_embed_attend(self, soft_cap_logits, is_training):
ref = soft_cap_logits * jnp.tanh(ref / soft_cap_logits)
assert_allclose(ref, actual_attends)

def test_embed_with_emb_scale(self):
seq_len = 5
vocab_size = 24
hidden_dim = 256

emb = TransformerTextEmbeddings.default_config().set(
name="embed",
dim=hidden_dim,
vocab_size=vocab_size,
)
emb.token_emb.set(scale=emb.token_emb.klass.Scale.UNIT)
layer = emb.instantiate(parent=None)

prng_key = jax.random.PRNGKey(1)
prng_key, init_key, data_key, fwd_key = jax.random.split(prng_key, num=4)
state = layer.initialize_parameters_recursively(init_key)

input_ids = jax.random.randint(data_key, shape=(3, seq_len), minval=1, maxval=vocab_size)
test_inputs = dict(inputs=input_ids)
outputs, _ = module.functional(
layer,
prng_key=fwd_key,
state=state,
inputs=dict(input_batch=test_inputs),
is_training=False,
)

assert_allclose(jnp.mean(outputs), 0.0, atol=0.05)
assert_allclose(jnp.std(outputs), 1.0, atol=0.05)


if __name__ == "__main__":
with utils.numeric_checks(True):
Expand Down
36 changes: 36 additions & 0 deletions axlearn/common/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Basic layers."""

import enum
import math
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -781,6 +782,21 @@ class Embedding(BaseLayer):
Batched map for int in [0, <num_embeddings>) -> <dim> float vector.
"""

class Scale(enum.Enum):
"""Defines the scale method on embedding activations.
Available types:
1. **UNIT**: Scale the activation components to ~1.
The activation component should roughly have a magnitude of 1. Since the embedding tensor is
initialized with a scale of `1/√dim`, the activation is multiplied by `√dim` to
maintain the desired scale. e.g. Gemma [1]
[1]
https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80
"""

UNIT = "unit"

@config_class
class Config(BaseLayer.Config):
"""Configures Embedding."""
Expand All @@ -793,6 +809,8 @@ class Config(BaseLayer.Config):
embedding_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None
# Optional scaling of the embedding activations.
scale: Optional["Embedding.Scale"] = None

@classmethod
def default_config(cls):
Expand Down Expand Up @@ -832,9 +850,27 @@ def forward(self, x: Tensor) -> Tensor:
emb = self.parameters["weight"]
emb = maybe_shard(emb, cfg.embedding_partition_spec)
activation = emb[x]
activation = self._scale(activation)
activation = maybe_shard(activation, cfg.output_partition_spec)
return activation

def _scale(self, x: Tensor) -> Tensor:
"""Scale the activation if needed."""
cfg = self.config
if cfg.scale is None:
return x

# Unsloth [1] discovered that `sqrt(dim)` needs to be computed in float32.
# [1] Sec 3 in https://unsloth.ai/blog/gemma-bugs.html
x_dtype = x.dtype
x = x.astype(jnp.float32)
if cfg.scale == self.Scale.UNIT:
x = x * math.sqrt(x.shape[-1])
else:
raise ValueError(f"Unknown scale {cfg.scale}.")
x = x.astype(x_dtype)
return x

def attend(self, x: Tensor) -> Tensor:
"""Apply query array 'x' to the embedding weight array.
Expand Down
31 changes: 26 additions & 5 deletions axlearn/common/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,11 +1230,10 @@ def test_moving_average(self):

class EmbedTest(parameterized.TestCase):
@staticmethod
def build_embedder(dim, num_embeddings, rng):
cfg = Embedding.default_config()
cfg.dim = dim
cfg.num_embeddings = num_embeddings
cfg.name = "embed"
def build_embedder(dim, num_embeddings, rng, **kwargs):
cfg = Embedding.default_config().set(name="embed", dim=dim, num_embeddings=num_embeddings)
if kwargs:
cfg = cfg.set(**kwargs)
emb = cfg.instantiate(parent=None)
state = emb.initialize_parameters_recursively(rng)
return (emb, state)
Expand All @@ -1249,6 +1248,28 @@ def test_embed_lookup(self, seq_len, dim, num_embeddings, is_training):
)
np.testing.assert_array_equal(state["weight"][ixs], actual_embeds)

def test_embed_with_scale(self):
dim = 256
num_embeddings = 16
prng_key = jax.random.PRNGKey(123)
prng_key, input_key, fwd_key = jax.random.split(prng_key, num=3)
embedder, state = EmbedTest.build_embedder(
dim, num_embeddings, input_key, scale=Embedding.Scale.UNIT
)
batch, seq_len = 5, 8
ixs = jax.random.randint(input_key, minval=0, maxval=num_embeddings, shape=(batch, seq_len))

outputs, _ = F(
embedder,
inputs=(ixs,),
is_training=True,
state=state,
prng_key=fwd_key,
)

assert_allclose(jnp.mean(outputs), 0.0, atol=0.05)
assert_allclose(jnp.std(outputs), 1.0, atol=0.05)

@parameterized.parameters(itertools.product((5, 7), (2, 16), (10, 100), (True, False)))
def test_embed_attend(self, seq_len, dim, num_embeddings, is_training):
rng = jax.random.PRNGKey(1)
Expand Down

0 comments on commit 8167248

Please sign in to comment.