Add schedule-free adamw submission in JAX #809

Open
priyakasimbeg opened this issue Oct 31, 2024 · 7 comments

@priyakasimbeg
Contributor

Description

So far we have been unable to reproduce the schedule-free AdamW results with JAX.
There appear to be differences between the optax implementation of schedule-free AdamW and the PyTorch submission.

@adefazio
Contributor

adefazio commented Nov 1, 2024

I can help debug any issues here. Do you have any code you can share? If there are issues with the optax JAX implementation, I want to get them fixed ASAP.

@adefazio
Contributor

adefazio commented Nov 6, 2024

There are many small differences between the behavior of the schedule-free JAX wrapper and the original AlgoPerf submission. Some differences I'm aware of:

  • The bias correction in the submission scales the weight decay at early steps. This is slightly faster for fastMRI but doesn't appear to affect any other workloads in my experiments.
  • Weight decay is applied at y in the JAX version. This decay-at-y version is very similar in my experiments, if not slightly better (when testing in PyTorch). The experiments in the schedule-free paper use this decay-at-y version.
  • There is an r=0.5 weighting in the submission version (a rough sketch is included below) - this seems to make little if any difference in practice (hard to tell due to noise).

So overall I expect the JAX wrapper version to give results that are just as good on all problems (maybe slightly slower on fastMRI), so if there is a difference it would be from some sort of bug.
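
For concreteness, here is a rough sketch of that r weighting (illustrative only; the names, defaults, and structure are placeholders rather than the submission's actual code, and with r = 0 it reduces to the plain schedule-free weighting):

```python
def averaging_coefficient(step, max_lr, weight_sum, r=0.5, weight_lr_power=2.0):
    # Weight assigned to the current iterate; r > 0 biases the running
    # average toward later steps.
    weight = ((step + 1) ** r) * (max_lr ** weight_lr_power)
    weight_sum = weight_sum + weight
    ckp1 = weight / weight_sum
    # The averaged iterate is then updated as
    #   x_{t+1} = (1 - ckp1) * x_t + ckp1 * z_{t+1}
    return ckp1, weight_sum
```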

@priyakasimbeg
Contributor Author

priyakasimbeg commented Nov 19, 2024

Hi Aaron! Thanks for weighing in on this. I seem to have missed your messages on this thread.

We have a slightly modified version based on the optax code here: https://github.com/priyakasimbeg/algorithmic-efficiency/blob/compare_schedule_free/tests/test_algorithms/schedule_free_adamw/jax/schedule_free_optax.py.
This code adds the r weighting, and we tested it with r=0.75 on our Google-internal codebase.

I'm working on a test to compare the PyTorch and JAX implementations side by side on the AlgoPerf GitHub code, but the test is still in progress (the rough shape of the check is sketched below). I can perhaps also do a full training run on some of the workloads.
In the meantime, feel free to weigh in again if you spot any other differences.
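
Roughly, the idea is to run both implementations for the same number of steps on identical data and check that the parameters agree within a tolerance; a minimal sketch (not the actual in-progress test, and the helper name is a placeholder):

```python
import numpy as np

def assert_params_close(pytorch_params, jax_params, atol=1e-5):
    # pytorch_params: dict of name -> torch.Tensor from the PyTorch submission.
    # jax_params: dict of name -> jax array from the JAX submission.
    for name, torch_tensor in pytorch_params.items():
        np.testing.assert_allclose(
            torch_tensor.detach().cpu().numpy(),
            np.asarray(jax_params[name]),
            atol=atol,
            err_msg=f"Parameter mismatch at {name}",
        )
```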

@adefazio
Contributor

Ok, I'll take a look and see if I spot any differences.

@adefazio
Contributor

It looks like the z buffer may be initialized with zeros:
https://github.com/priyakasimbeg/algorithmic-efficiency/blob/5556015054e3dda681e2a25e05a2f217d933453d/tests/test_algorithms/schedule_free_adamw/jax/submission.py#L58C51-L59C1
It needs to be initialized the same as the main parameter buffer. I think this line is a copy-paste error from the JAX version of NAdamW and other methods, where all optimizer state is normally initialized at zero.

Suggestion: you might want to set z on the first call to the main optimizer update; that's what we do in the PyTorch version.
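
To illustrate the fix, a minimal sketch (not the submission's actual init code; the function name is a placeholder):

```python
import jax
import jax.numpy as jnp

def init_z(params):
    # Incorrect (a NAdamW-style init, where all optimizer state starts at zero):
    #   z = jax.tree_util.tree_map(jnp.zeros_like, params)
    # Correct: z tracks the schedule-free iterate sequence, so it has to start
    # as a copy of the initial parameters.
    return jax.tree_util.tree_map(jnp.copy, params)
```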

@adefazio
Contributor

@priyakasimbeg Let me know if that initialization issue was the problem.

@priyakasimbeg
Contributor Author

Hi Aaron, thanks for spotting that!
We did the end-to-end training with an internal codebase using the optax implementation: https://github.com/google-deepmind/optax/blob/3ba9822c2a8d5fa7d046180b2574e108094523b4/optax/contrib/_schedule_free.py. It's not immediately obvious to me how z is initialized there, but I will investigate and report back here.
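
One quick way to check is a sketch like the following (assuming the wrapper is exposed as optax.contrib.schedule_free and that its state carries a z field, as in the linked file; those names are assumptions here):

```python
import jax
import jax.numpy as jnp
import optax
from optax import contrib

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
# Assumes contrib.schedule_free wraps a base optimizer, as in the linked source.
tx = contrib.schedule_free(optax.adam(1e-3), learning_rate=1e-3)
state = tx.init(params)

# If z starts as a copy of the parameters, every leaf should match exactly.
print(jax.tree_util.tree_map(lambda z, p: bool(jnp.array_equal(z, p)),
                             state.z, params))
```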
