Skip to content

Commit

Permalink
Raise error if result_archive and archive have different fields (#461)
Browse files Browse the repository at this point in the history

## Description

<!-- Provide a brief description of the PR's purpose here. -->

In the scheduler, we did not previously check that result_archive and
archive had the same fields. This could lead to confusing errors where
we insert certain data into the archive, but it fails to insert into the
result_archive. This PR adds a check so that this does not happen. It
also fixes some other scheduler documentation.

## Status

- [x] I have read the guidelines in
  [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in
  `HISTORY.md`
- [x] This PR is ready to go
btjanaka committed Mar 14, 2024

Verified

This commit was signed with the committer’s verified signature.
btjanaka Bryon Tjanaka
1 parent 45b1ab7 commit 23e90a3
Showing 4 changed files with 70 additions and 14 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@
#### Improvements

- Add qd score to lunar lander example ({pr}`458`)
- Raise error if `result_archive` and `archive` have different fields
({pr}`461`)

#### Documentation

19 changes: 15 additions & 4 deletions ribs/schedulers/_bandit_scheduler.py
Original file line number Diff line number Diff line change
@@ -58,18 +58,22 @@ class BanditScheduler:
since it is significantly faster.
result_archive (ribs.archives.ArchiveBase): In some algorithms, such as
CMA-MAE, the archive does not store all the best-performing
solutions. The `result_archive` is a secondary archive where we can
store all the best-performing solutions.
solutions. The ``result_archive`` is a secondary archive where we
can store all the best-performing solutions.
Raises:
TypeError: The `emitter_pool` argument was not a list of emitters.
TypeError: The ``emitter_pool`` argument was not a list of emitters.
ValueError: Number of active emitters is less than one.
ValueError: Less emitters in the pool than the number of active
emitters.
ValueError: The emitters passed in do not have the same solution
dimensions.
ValueError: The same emitter instance was passed in multiple times.
Each emitter should be a unique instance (see the warning above).
ValueError: Invalid value for `add_mode`.
ValueError: Invalid value for ``add_mode``.
ValueError: The ``result_archive`` and ``archive`` are the same object
(``result_archive`` should not be passed in in this case).
ValueError: The ``result_archive`` and ``archive`` have different
fields.
"""

def __init__(self,
@@ -128,6 +132,13 @@ def __init__(self,
"defaults to be the same as `archive` if you pass "
"`result_archive=None`")

if (result_archive is not None and
set(archive.field_list) != set(result_archive.field_list)):
raise ValueError("`archive` and `result_archive` should have the "
"same set of fields. This may be the result of "
"passing extra_fields to archive but not to "
"result_archive.")

self._archive = archive
self._emitter_pool = np.array(emitter_pool)
self._num_active = num_active
19 changes: 15 additions & 4 deletions ribs/schedulers/_scheduler.py
Original file line number Diff line number Diff line change
@@ -39,16 +39,20 @@ class Scheduler:
since it is significantly faster.
result_archive (ribs.archives.ArchiveBase): In some algorithms, such as
CMA-MAE, the archive does not store all the best-performing
solutions. The `result_archive` is a secondary archive where we can
store all the best-performing solutions.
solutions. The ``result_archive`` is a secondary archive where we
can store all the best-performing solutions.
Raises:
TypeError: The `emitters` argument was not a list of emitters.
TypeError: The ``emitters`` argument was not a list of emitters.
ValueError: The emitters passed in do not have the same solution
dimensions.
ValueError: There is no emitter passed in.
ValueError: The same emitter instance was passed in multiple times. Each
emitter should be a unique instance (see the warning above).
ValueError: Invalid value for `add_mode`.
ValueError: Invalid value for ``add_mode``.
ValueError: The ``result_archive`` and ``archive`` are the same object
(``result_archive`` should not be passed in in this case).
ValueError: The ``result_archive`` and ``archive`` have different
fields.
"""

def __init__(self,
@@ -96,6 +100,13 @@ def __init__(self,
"defaults to be the same as `archive` if you pass "
"`result_archive=None`")

if (result_archive is not None and
set(archive.field_list) != set(result_archive.field_list)):
raise ValueError("`archive` and `result_archive` should have the "
"same set of fields. This may be the result of "
"passing extra_fields to archive but not to "
"result_archive.")

self._archive = archive
self._emitters = emitters
self._add_mode = add_mode
44 changes: 38 additions & 6 deletions tests/schedulers/scheduler_test.py
Original file line number Diff line number Diff line change
@@ -83,15 +83,15 @@ def test_ask_fails_when_called_twice(scheduler_fixture):
scheduler.ask()


@pytest.mark.parametrize("archive_type", ["Scheduler", "BanditScheduler"])
def test_warn_nothing_added_to_archive(archive_type):
@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"])
def test_warn_nothing_added_to_archive(scheduler_type):
archive = GridArchive(solution_dim=2,
dims=[100, 100],
ranges=[(-1, 1), (-1, 1)],
threshold_min=1.0,
learning_rate=1.0)
emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)]
if archive_type == "Scheduler":
if scheduler_type == "Scheduler":
scheduler = Scheduler(archive, emitters)
else:
scheduler = BanditScheduler(archive, emitters, 1)
@@ -106,8 +106,8 @@ def test_warn_nothing_added_to_archive(archive_type):
)


@pytest.mark.parametrize("archive_type", ["Scheduler", "BanditScheduler"])
def test_warn_nothing_added_to_result_archive(archive_type):
@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"])
def test_warn_nothing_added_to_result_archive(scheduler_type):
archive = GridArchive(solution_dim=2,
dims=[100, 100],
ranges=[(-1, 1), (-1, 1)],
@@ -119,7 +119,7 @@ def test_warn_nothing_added_to_result_archive(archive_type):
threshold_min=10.0,
learning_rate=1.0)
emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)]
if archive_type == "Scheduler":
if scheduler_type == "Scheduler":
scheduler = Scheduler(
archive,
emitters,
@@ -143,6 +143,38 @@ def test_warn_nothing_added_to_result_archive(archive_type):
)


@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"])
def test_result_archive_mismatch_fields(scheduler_type):
    """Check that a scheduler rejects archives whose fields do not match.

    ``archive`` is built with two extra fields while ``result_archive`` is
    built with none, so constructing either scheduler type with this pair is
    expected to raise a ``ValueError``.
    """
    extra_fields = {
        "metadata": ((), object),
        "square": ((2, 2), np.int32),
    }
    archive = GridArchive(
        solution_dim=2,
        dims=[100, 100],
        ranges=[(-1, 1), (-1, 1)],
        threshold_min=-np.inf,
        learning_rate=1.0,
        extra_fields=extra_fields,
    )
    result_archive = GridArchive(
        solution_dim=2,
        dims=[100, 100],
        ranges=[(-1, 1), (-1, 1)],
    )
    emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)]

    # BanditScheduler is called with one more positional argument than
    # Scheduler (presumably the number of active emitters -- confirm against
    # the BanditScheduler signature).
    if scheduler_type == "Scheduler":
        scheduler_cls, positional = Scheduler, (archive, emitters)
    else:
        scheduler_cls, positional = BanditScheduler, (archive, emitters, 1)

    with pytest.raises(ValueError):
        scheduler_cls(*positional, result_archive=result_archive)


def test_tell_inserts_solutions_into_archive(add_mode):
batch_size = 4
archive = GridArchive(solution_dim=2,

0 comments on commit 23e90a3

Please sign in to comment.