3.11.6 test failed tests.schedulers.test_scheduler_flax.FlaxDDPMSchedulerTest.test_full_loop_no_noise #138

Open
SamuelMarks opened this issue Dec 19, 2024 · 7 comments

@SamuelMarks

tests/schedulers/test_scheduler_flax.py:304 (FlaxDDPMSchedulerTest.test_full_loop_no_noise)
Array(3.7847595, dtype=float32) != 0.01

Expected :0.01
Actual   :Array(3.7847595, dtype=float32)
self = <test_scheduler_flax.FlaxDDPMSchedulerTest testMethod=test_full_loop_no_noise>

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
    
        num_trained_timesteps = len(scheduler)
    
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))
    
        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)
    
            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)
    
            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample
    
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))
    
        if jax_device == "tpu":
            assert abs(result_sum - 251.26245) < 1e-2
            assert abs(result_mean - 0.32716465) < 1e-3
        else:
>           assert abs(result_sum - 255.1113) < 1e-2
E           assert Array(3.7847595, dtype=float32) < 0.01
E            +  where Array(3.7847595, dtype=float32) = abs((Array(251.32654, dtype=float32) - 255.1113))

schedulers/test_scheduler_flax.py:341: AssertionError
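The arithmetic behind the failing assertion can be reproduced without JAX. This is only a sketch: the numbers are copied from the traceback above, and the helper name `sum_error` and the `EXPECTED_SUM` table are hypothetical, not part of the test suite.

```python
# Values copied from the traceback; names below are illustrative only.
EXPECTED_SUM = {"tpu": 251.26245, "default": 255.1113}
ACTUAL_SUM = 251.32654  # result_sum reported in the assertion message
TOLERANCE = 1e-2


def sum_error(actual: float, device: str) -> float:
    """Absolute gap between the observed sum and the expected sum for a device."""
    expected = EXPECTED_SUM.get(device, EXPECTED_SUM["default"])
    return abs(actual - expected)


# Compare the observed sum against both hard-coded branches of the test.
for device in ("cpu", "tpu"):
    err = sum_error(ACTUAL_SUM, device)
    print(f"{device}: error={err:.5f} passes={err < TOLERANCE}")
```

With these numbers the observed sum misses the non-TPU reference by about 3.78, and it also falls outside the 1e-2 tolerance of the TPU reference, so neither branch would pass.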

I'm running this on an M3 Pro, without a TPU or GPU.

I plan to go through all your tests and dependencies until 3.10, 3.11, 3.12, and 3.13 are supported in addition to your existing 3.8 and 3.9 support.

PS: Your grain-nightly dependency doesn't seem to support 3.8 or 3.9:

ERROR: Ignored the following versions that require a different python version: 0.0.1 Requires-Python >=3.10; 0.0.2 Requires-Python >=3.10; 0.0.3 Requires-Python >=3.10; 0.0.4 Requires-Python >=3.10
ERROR: Could not find a version that satisfies the requirement grain-nightly (from versions: none)
ERROR: No matching distribution found for grain-nightly

Is your setup.py up to date? Which Python (CPython) versions are you testing on?

@entrpn
Collaborator

entrpn commented Dec 19, 2024

@SamuelMarks looking at setup.py, I do see that it needs to be updated. I use Python 3.10, currently 3.10.15.

The tests in tests/schedulers/test_scheduler_flax.py came over from diffusers when we forked. I don't maintain them much; the tests we write are located here

@aireenmei can you help answer the grain question?

@aireenmei
Collaborator

Judging by the wheels on PyPI, grain-nightly should support Python 3.9 to 3.12: https://pypi.org/project/grain-nightly/#files

@SamuelMarks
Author

Confirmed a similar issue on Python 3.9.21 on x86_64 Linux, running pip 24.3.1, wheel 0.45.1, and setuptools 75.6.0.

$ python -m pip install -r requirements.txt 
Collecting git+https://github.com/mlperf/logging.git (from -r requirements.txt (line 25))
  Cloning https://github.com/mlperf/logging.git to /tmp/pip-req-build-8p__ik31
  Running command git clone --filter=blob:none --quiet https://github.com/mlperf/logging.git /tmp/pip-req-build-8p__ik31
  Resolved https://github.com/mlperf/logging.git to commit eb9e1a39bc313d964e9c1955d76384a6f3a731d3
  Preparing metadata (setup.py) ... done
Collecting jax>=0.4.30 (from -r requirements.txt (line 1))
  Downloading jax-0.4.30-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib>=0.4.30 (from -r requirements.txt (line 2))
  Downloading jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl.metadata (1.0 kB)
Collecting grain-nightly (from -r requirements.txt (line 3))
  Downloading grain_nightly-0.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting google-cloud-storage==2.17.0 (from -r requirements.txt (line 4))
  Downloading google_cloud_storage-2.17.0-py2.py3-none-any.whl.metadata (6.6 kB)
Collecting absl-py (from -r requirements.txt (line 5))
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting datasets (from -r requirements.txt (line 6))
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
ERROR: Ignored the following yanked versions: 0.6.5, 0.7.1, 0.7.3
ERROR: Ignored the following versions that require a different python version: 0.0.1 Requires-Python >=3.10; 0.0.2 Requires-Python >=3.10; 0.0.3 Requires-Python >=3.10; 0.0.4 Requires-Python >=3.10; 0.10.0 Requires-Python >=3.10; 0.10.1 Requires-Python >=3.10; 0.10.2 Requires-Python >=3.10; 0.4.31 Requires-Python >=3.10; 0.4.32 Requires-Python >=3.10; 0.4.33 Requires-Python >=3.10; 0.4.34 Requires-Python >=3.10; 0.4.35 Requires-Python >=3.10; 0.4.36 Requires-Python >=3.10; 0.4.37 Requires-Python >=3.10; 0.4.38 Requires-Python >=3.10; 0.9.0 Requires-Python >=3.10
ERROR: Could not find a version that satisfies the requirement flax>=0.10.2 (from versions: 0.1.0rc1, 0.1.0rc2, 0.1.0, 0.2.0, 0.2.1, 0.2.2, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.4.0, 0.4.1, 0.4.2, 0.5.0, 0.5.1, 0.5.2, 0.5.3, 0.6.0, 0.6.1, 0.6.2, 0.6.3, 0.6.4, 0.6.6, 0.6.7, 0.6.8, 0.6.9, 0.6.10, 0.6.11, 0.7.0, 0.7.2, 0.7.4, 0.7.5, 0.8.0, 0.8.1, 0.8.2, 0.8.3, 0.8.4, 0.8.5)
ERROR: No matching distribution found for flax>=0.10.2

This time the failure is with flax.

flax 0.8.5 installs fine when I run python -m pip install flax, but it looks like you need 0.10.2 or higher.
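To illustrate why pip skipped the newer releases on 3.9, here is a minimal sketch of a `Requires-Python` check. It handles only the `>=X.Y` form that appears in the log above; the function name is hypothetical, and pip itself relies on the full specifier logic in the `packaging` library rather than anything this simple.

```python
import re
import sys


def meets_requires_python(spec: str, version: tuple) -> bool:
    """Return True if `version` satisfies a simple '>=X.Y' Requires-Python spec.

    Illustrative only: other specifier forms (==, !=, <, ~=, comma lists)
    are deliberately rejected instead of being guessed at.
    """
    match = re.fullmatch(r"\s*>=\s*(\d+)\.(\d+)\s*", spec)
    if match is None:
        raise ValueError(f"unsupported specifier: {spec!r}")
    minimum = (int(match.group(1)), int(match.group(2)))
    # Compare only (major, minor), which is what '>=3.10' constrains here.
    return tuple(version[:2]) >= minimum


# A release marked 'Requires-Python >=3.10' is ignored on a 3.9 interpreter.
print(meets_requires_python(">=3.10", (3, 9)))
# The same check against the interpreter running this snippet.
print(meets_requires_python(">=3.10", sys.version_info))
```

Under that rule, every flax release marked `>=3.10` is filtered out on 3.9, which leaves 0.8.5 as the newest candidate and makes `flax>=0.10.2` unsatisfiable.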

Which versions of Python would you like to support? I'm happy to send PRs everywhere to enable this. Preferably the list would include Python 3.13, 3.12, and 3.11.

@entrpn
Collaborator

entrpn commented Dec 26, 2024

Ideally we want to support Python 3.10 and up. Thank you for your help.

@SamuelMarks
Author

OK, let me see what I can do. Latest update (testing on an M3 Pro):

No errors when running the tests on CPython 3.10.16.

Looks like my encouragement helped:
google/grain#662

As for 3.11.11 and 3.12.8, this is the error:

FAILED [ 33%]
tests/schedulers/test_scheduler_flax.py:304 (FlaxDDPMSchedulerTest.test_full_loop_no_noise)
Array(3.7847595, dtype=float32) != 0.01

Expected :0.01
Actual   :Array(3.7847595, dtype=float32)
self = <test_scheduler_flax.FlaxDDPMSchedulerTest testMethod=test_full_loop_no_noise>

(i.e., same as this issue)

BTW, 3.13.1 failed to install the requirements:

ERROR: No matching distribution found for torch>=2.3.1

Relevant PR for this last failure: pytorch/pytorch#130249

@entrpn
Collaborator

entrpn commented Jan 28, 2025

@SamuelMarks the scheduler failure is caused by JAX's recent change that enables jax_threefry_partitionable by default, which changes the random numbers that are generated. See jax-ml/jax#18480

I'm making a fix for this.
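For anyone who wants to see the effect locally, the sketch below draws from the same seed with the flag forced off and then on. It assumes a JAX version that still exposes the `jax_threefry_partitionable` config option; the helper name `normal_with_flag` is made up for illustration, and the thread does not say whether pinning the flag is the fix that was chosen.

```python
import jax
import jax.numpy as jnp
from jax import random

FLAG = "jax_threefry_partitionable"


def normal_with_flag(partitionable: bool, seed: int = 0, shape=(4,)):
    """Draw normals from a fixed seed with the PRNG flag forced on or off."""
    previous = getattr(jax.config, FLAG)  # remember the current setting
    jax.config.update(FLAG, partitionable)
    try:
        return random.normal(random.PRNGKey(seed), shape)
    finally:
        # Restore the global flag so later code is unaffected.
        jax.config.update(FLAG, previous)


legacy = normal_with_flag(False)
partitionable = normal_with_flag(True)
print("flag off:", legacy)
print("flag on: ", partitionable)
print("identical:", bool(jnp.array_equal(legacy, partitionable)))
```

Each setting is deterministic on its own, so reference values recorded under one setting can only be expected to match when the test runs under that same setting.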

@SamuelMarks
Author

Great. It also looks like the next PyTorch release is scheduled for tomorrow, so that should be a good time to cut a new release with 3.13 support as well.


3 participants