
Introduce the scale enum flag in Embedding layer for LLM embedding. #909

Merged 1 commit into apple:main on Jan 8, 2025

Conversation

@ds-hwang (Contributor) commented on Jan 8, 2025

The activation component should roughly have a magnitude of 1. Since the embedding tensor is initialized with a scale of `1/sqrt(dim)`, the activation is multiplied by `sqrt(dim)` to maintain the desired scale, e.g., Gemma [1].
[1] https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80

In addition, unsloth [2] discovered that `sqrt(dim)` needs to be computed in float32.
[2] Sec 3 in https://unsloth.ai/blog/gemma-bugs

TODO(axlearn-team): Use UNIT scale enum for AFM+. This will require re-sweeping hyperparameters (e.g., learning rate).
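The scaling described above can be sketched as follows. This is a minimal illustration in plain numpy, not the actual axlearn implementation; the function name and signature are hypothetical. It shows the two points from the description: the lookup is rescaled by `sqrt(dim)` to undo the `1/sqrt(dim)` initialization scale, and the scale is computed in float32 even when the embedding table is stored in a lower precision.

```python
import numpy as np

def scaled_embedding(emb: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Look up embeddings and rescale activations to roughly unit magnitude.

    `emb` is assumed to be initialized with scale 1/sqrt(dim), so the lookup
    is multiplied by sqrt(dim). Per the unsloth finding, sqrt(dim) is computed
    in float32 rather than in the (possibly bf16/fp16) embedding dtype.
    """
    dim = emb.shape[-1]
    scale = np.sqrt(np.float32(dim))  # computed in float32
    x = emb[ids]                       # [..., dim] lookup
    return (x.astype(np.float32) * scale).astype(emb.dtype)
```

With a table initialized at scale `1/sqrt(dim)`, the output activations come back at roughly unit scale.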

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners January 8, 2025 04:10
@ds-hwang (Contributor, Author) commented on Jan 8, 2025

@ruomingp Could you review? From 970

```python
    https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80
    """

    UNIT = 1
```
Reviewer (Contributor):

Nit: use str values to be more readable.

Suggested change:
UNIT = 1
UNIT = "unit"
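To illustrate the reviewer's point, here is a minimal sketch (the class name `EmbeddingScale` and the `NONE` member are hypothetical, for illustration only, not the actual axlearn enum): string-valued enum members round-trip through configs and render readably in logs, whereas integer values like `1` carry no meaning on their own.

```python
from enum import Enum, unique

@unique
class EmbeddingScale(Enum):
    # String values serialize/log as human-readable names.
    NONE = "none"
    UNIT = "unit"

# The value is self-describing in a dumped config:
print(EmbeddingScale.UNIT.value)   # prints "unit"
# And a string from a config file maps back to the member:
print(EmbeddingScale("unit"))      # prints "EmbeddingScale.UNIT"
```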

Contributor Author:
Yes, Done.

@ds-hwang ds-hwang added this pull request to the merge queue Jan 8, 2025
Merged via the queue into apple:main with commit 2d1fb29 Jan 8, 2025
6 checks passed
@ds-hwang ds-hwang deleted the emb_scale branch January 8, 2025 22:31