Skip to content

Commit

Permalink
Fixed over asking bug (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead authored Sep 13, 2022
1 parent d1a2856 commit 7131736
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
7 changes: 7 additions & 0 deletions tests/test_wazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,10 @@ def batch_ask(self):
assert len(x) == 2 * 4
# make sure no dups
assert len(set(x)) == len(x)

def test_overask(self):
key = jax.random.PRNGKey(0)
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
for a in ALPHABET:
boa.tell(key, a, 1)
x, _ = boa.ask(key)
18 changes: 14 additions & 4 deletions wazy/asktell.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import jax.numpy as jnp
import numpy as np
from functools import partial
Expand Down Expand Up @@ -122,11 +123,20 @@ def ask(self, key, aq_fxn="ucb", length=None, return_seqs=1):
seq = None
min_idxs = jnp.argsort(jnp.squeeze(bo_loss[-1]))
out_seq = ["".join(decode_seq(batched_v[0][i])) for i in min_idxs]
out_loss = [bo_loss[-1][i] for i in min_idxs if out_seq[i] not in self.seqs]
out_seq = [o for o in out_seq if o not in self.seqs]
out_loss = [bo_loss[-1][i] for i in min_idxs]
filtered_out_loss = [
out_loss[i] for i in range(len(out_seq)) if out_seq[i] not in self.seqs
]
filtered_out_seq = [o for o in out_seq if o not in self.seqs]
if len(filtered_out_seq) < return_seqs:
warnings.warn(
"Not enough unique sequences to return - returning duplicates"
)
filtered_out_seq = out_seq
filtered_out_loss = out_loss
if return_seqs == 1:
return out_seq[0], out_loss[0]
return out_seq[:return_seqs], out_loss[:return_seqs]
return filtered_out_seq[0], filtered_out_loss[0]
return filtered_out_seq[:return_seqs], filtered_out_loss[:return_seqs]

def batch_ask(self, key, N, aq_fxn="ucb", lengths=None, return_seqs=1):
"""Batch asking iteratively asking and telling min value
Expand Down
2 changes: 1 addition & 1 deletion wazy/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.0"
__version__ = "0.10.1"

0 comments on commit 7131736

Please sign in to comment.