From 23e90a36011c8b08941cab507ce54a20ebc692a2 Mon Sep 17 00:00:00 2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Thu, 14 Mar 2024 01:22:28 -0700 Subject: [PATCH] Raise error if `result_archive` and `archive` have different fields (#461) ## Description In the scheduler, we did not previously check that result_archive and archive had the same fields. This could lead to confusing errors where we insert certain data into the archive, but it fails to insert into the result_archive. This PR adds a check so that this does not happen. It also fixes some other scheduler documentation. ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [x] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go --- HISTORY.md | 2 ++ ribs/schedulers/_bandit_scheduler.py | 19 +++++++++--- ribs/schedulers/_scheduler.py | 19 +++++++++--- tests/schedulers/scheduler_test.py | 44 ++++++++++++++++++++++++---- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index bf46f29fa..314e46b2a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -11,6 +11,8 @@ #### Improvements - Add qd score to lunar lander example ({pr}`458`) +- Raise error if `result_archive` and `archive` have different fields + ({pr}`461`) #### Documentation diff --git a/ribs/schedulers/_bandit_scheduler.py b/ribs/schedulers/_bandit_scheduler.py index 9a5785b6a..b382da086 100644 --- a/ribs/schedulers/_bandit_scheduler.py +++ b/ribs/schedulers/_bandit_scheduler.py @@ -58,10 +58,10 @@ class BanditScheduler: since it is significantly faster. result_archive (ribs.archives.ArchiveBase): In some algorithms, such as CMA-MAE, the archive does not store all the best-performing - solutions. The `result_archive` is a secondary archive where we can - store all the best-performing solutions. + solutions. The ``result_archive`` is a secondary archive where we + can store all the best-performing solutions. Raises: - TypeError: The `emitter_pool` argument was not a list of emitters. + TypeError: The ``emitter_pool`` argument was not a list of emitters. ValueError: Number of active emitters is less than one. ValueError: Less emitters in the pool than the number of active emitters. @@ -69,7 +69,11 @@ class BanditScheduler: dimensions. ValueError: The same emitter instance was passed in multiple times. Each emitter should be a unique instance (see the warning above). - ValueError: Invalid value for `add_mode`. + ValueError: Invalid value for ``add_mode``. + ValueError: The ``result_archive`` and ``archive`` are the same object + (``result_archive`` should not be passed in in this case). + ValueError: The ``result_archive`` and ``archive`` have different + fields. """ def __init__(self, @@ -128,6 +132,13 @@ def __init__(self, "defaults to be the same as `archive` if you pass " "`result_archive=None`") + if (result_archive is not None and + set(archive.field_list) != set(result_archive.field_list)): + raise ValueError("`archive` and `result_archive` should have the " + "same set of fields. This may be the result of " + "passing extra_fields to archive but not to " + "result_archive.") + self._archive = archive self._emitter_pool = np.array(emitter_pool) self._num_active = num_active diff --git a/ribs/schedulers/_scheduler.py b/ribs/schedulers/_scheduler.py index 1f89c7655..db80443ad 100644 --- a/ribs/schedulers/_scheduler.py +++ b/ribs/schedulers/_scheduler.py @@ -39,16 +39,20 @@ class Scheduler: since it is significantly faster. result_archive (ribs.archives.ArchiveBase): In some algorithms, such as CMA-MAE, the archive does not store all the best-performing - solutions. The `result_archive` is a secondary archive where we can - store all the best-performing solutions. + solutions. The ``result_archive`` is a secondary archive where we + can store all the best-performing solutions. Raises: - TypeError: The `emitters` argument was not a list of emitters. + TypeError: The ``emitters`` argument was not a list of emitters. ValueError: The emitters passed in do not have the same solution dimensions. ValueError: There is no emitter passed in. ValueError: The same emitter instance was passed in multiple times. Each emitter should be a unique instance (see the warning above). - ValueError: Invalid value for `add_mode`. + ValueError: Invalid value for ``add_mode``. + ValueError: The ``result_archive`` and ``archive`` are the same object + (``result_archive`` should not be passed in in this case). + ValueError: The ``result_archive`` and ``archive`` have different + fields. """ def __init__(self, @@ -96,6 +100,13 @@ def __init__(self, "defaults to be the same as `archive` if you pass " "`result_archive=None`") + if (result_archive is not None and + set(archive.field_list) != set(result_archive.field_list)): + raise ValueError("`archive` and `result_archive` should have the " + "same set of fields. This may be the result of " + "passing extra_fields to archive but not to " + "result_archive.") + self._archive = archive self._emitters = emitters self._add_mode = add_mode diff --git a/tests/schedulers/scheduler_test.py b/tests/schedulers/scheduler_test.py index beaf39a71..cf505690d 100644 --- a/tests/schedulers/scheduler_test.py +++ b/tests/schedulers/scheduler_test.py @@ -83,15 +83,15 @@ def test_ask_fails_when_called_twice(scheduler_fixture): scheduler.ask() -@pytest.mark.parametrize("archive_type", ["Scheduler", "BanditScheduler"]) -def test_warn_nothing_added_to_archive(archive_type): +@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"]) +def test_warn_nothing_added_to_archive(scheduler_type): archive = GridArchive(solution_dim=2, dims=[100, 100], ranges=[(-1, 1), (-1, 1)], threshold_min=1.0, learning_rate=1.0) emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)] - if archive_type == "Scheduler": + if scheduler_type == "Scheduler": scheduler = Scheduler(archive, emitters) else: scheduler = BanditScheduler(archive, emitters, 1) @@ -106,8 +106,8 @@ def test_warn_nothing_added_to_archive(archive_type): ) -@pytest.mark.parametrize("archive_type", ["Scheduler", "BanditScheduler"]) -def test_warn_nothing_added_to_result_archive(archive_type): +@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"]) +def test_warn_nothing_added_to_result_archive(scheduler_type): archive = GridArchive(solution_dim=2, dims=[100, 100], ranges=[(-1, 1), (-1, 1)], @@ -119,7 +119,7 @@ def test_warn_nothing_added_to_result_archive(archive_type): threshold_min=10.0, learning_rate=1.0) emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)] - if archive_type == "Scheduler": + if scheduler_type == "Scheduler": scheduler = Scheduler( archive, emitters, @@ -143,6 +143,38 @@ def test_warn_nothing_added_to_result_archive(archive_type): ) +@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"]) +def test_result_archive_mismatch_fields(scheduler_type): + archive = GridArchive(solution_dim=2, + dims=[100, 100], + ranges=[(-1, 1), (-1, 1)], + threshold_min=-np.inf, + learning_rate=1.0, + extra_fields={ + "metadata": ((), object), + "square": ((2, 2), np.int32) + }) + result_archive = GridArchive(solution_dim=2, + dims=[100, 100], + ranges=[(-1, 1), (-1, 1)]) + emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)] + + with pytest.raises(ValueError): + if scheduler_type == "Scheduler": + Scheduler( + archive, + emitters, + result_archive=result_archive, + ) + else: + BanditScheduler( + archive, + emitters, + 1, + result_archive=result_archive, + ) + + def test_tell_inserts_solutions_into_archive(add_mode): batch_size = 4 archive = GridArchive(solution_dim=2,