
Fix the torchax llama405b OOM at model init time #24

Merged
merged 1 commit into main from yifeit/torchax-param-oom on Jan 11, 2025

Conversation

tengyifei
Collaborator

Instead of holding a large global `weight_jax` array, we hold a meta tensor and create a local JAX array whose size and dtype correspond to those of a single shard.

The program still OOMs host memory later while compiling the training step, but that will be addressed separately.
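The idea, as a minimal sketch (illustrative only, not the PR's actual code; the mesh layout, shapes, and the `init_shard` helper are assumptions): keep the full-size parameter on the meta device so it allocates no storage, and materialize only the shard each device owns as a JAX array.

```python
# Illustrative sketch only -- not the code from this PR.
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# The full-size parameter lives on the meta device: it records shape and dtype
# but allocates no storage, so a very large model does not OOM the host.
weight_meta = torch.empty(8192, 8192, device="meta", dtype=torch.bfloat16)

# Hypothetical 1-D mesh sharding the first dimension across all devices
# (assumes the dimension divides evenly by the device count).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
sharding = NamedSharding(mesh, P("fsdp", None))

def init_shard(index):
    """Initialize only the slice of the weight owned by one device."""
    shard_shape = tuple(
        (s.stop - s.start) if s.start is not None else dim
        for s, dim in zip(index, weight_meta.shape)
    )
    # Real code would run the model's initializer; zeros keep the sketch simple.
    return jnp.zeros(shard_shape, dtype=jnp.bfloat16)

# Build the global array from per-shard callbacks; only local shards are allocated.
weight_jax = jax.make_array_from_callback(
    tuple(weight_meta.shape), sharding, init_shard)
```

`jax.make_array_from_callback` invokes the callback only for shards addressable from the current process, so each host allocates at most its local slices of the weights rather than the full global array.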

@tengyifei tengyifei requested a review from qihqi January 10, 2025 18:48
@tengyifei tengyifei merged commit 04cd0a9 into main Jan 11, 2025
6 checks passed
@tengyifei tengyifei deleted the yifeit/torchax-param-oom branch January 26, 2025 07:10