From 7131736e9973bee746de42fcd0ef0aa3cb168d80 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 13 Sep 2022 16:43:14 -0400 Subject: [PATCH] Fixed over asking bug (#11) --- tests/test_wazy.py | 7 +++++++ wazy/asktell.py | 18 ++++++++++++++---- wazy/version.py | 2 +- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/test_wazy.py b/tests/test_wazy.py index d85ef1c..3a51618 100644 --- a/tests/test_wazy.py +++ b/tests/test_wazy.py @@ -257,3 +257,10 @@ def batch_ask(self): assert len(x) == 2 * 4 # make sure no dups assert len(set(x)) == len(x) + + def test_overask(self): + key = jax.random.PRNGKey(0) + boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10)) + for a in ALPHABET: + boa.tell(key, a, 1) + x, _ = boa.ask(key) diff --git a/wazy/asktell.py b/wazy/asktell.py index d285016..bb66366 100644 --- a/wazy/asktell.py +++ b/wazy/asktell.py @@ -1,3 +1,4 @@ +import warnings import jax.numpy as jnp import numpy as np from functools import partial @@ -122,11 +123,20 @@ def ask(self, key, aq_fxn="ucb", length=None, return_seqs=1): seq = None min_idxs = jnp.argsort(jnp.squeeze(bo_loss[-1])) out_seq = ["".join(decode_seq(batched_v[0][i])) for i in min_idxs] - out_loss = [bo_loss[-1][i] for i in min_idxs if out_seq[i] not in self.seqs] - out_seq = [o for o in out_seq if o not in self.seqs] + out_loss = [bo_loss[-1][i] for i in min_idxs] + filtered_out_loss = [ + out_loss[i] for i in range(len(out_seq)) if out_seq[i] not in self.seqs + ] + filtered_out_seq = [o for o in out_seq if o not in self.seqs] + if len(filtered_out_seq) < return_seqs: + warnings.warn( + "Not enough unique sequences to return - returning duplicates" + ) + filtered_out_seq = out_seq + filtered_out_loss = out_loss if return_seqs == 1: - return out_seq[0], out_loss[0] - return out_seq[:return_seqs], out_loss[:return_seqs] + return filtered_out_seq[0], filtered_out_loss[0] + return filtered_out_seq[:return_seqs], filtered_out_loss[:return_seqs] def batch_ask(self, key, N, aq_fxn="ucb", lengths=None, return_seqs=1): """Batch asking iteratively asking and telling min value diff --git a/wazy/version.py b/wazy/version.py index 61fb31c..1f4c4d4 100644 --- a/wazy/version.py +++ b/wazy/version.py @@ -1 +1 @@ -__version__ = "0.10.0" +__version__ = "0.10.1"