Skip to content

Commit

Permalink
Added batched asking (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead authored Aug 9, 2022
1 parent dd3bb2e commit 9fc7f94
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 14 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,19 @@ boa.ask(key, "max", length=5)
# Output
('DAAAA', 3.8262398)
```

### Batching

You can increase the number of returned sequences by using the `batch_ask`, which uses an ad-hoc regret minimization strategy to spread out the proposed sequences:

```py
boa.batch_ask(key, N=3)
# returns 3 seqs
```

and you can add a multiplier to batch sequences (no overhead), but they may be similar

```py
boa.batch_ask(key, N=3, return_seqs = 10)
# returns 30 seqs
```
21 changes: 17 additions & 4 deletions tests/test_wazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,32 +213,45 @@ def x0_gen(key, batch_size, L):
class TestAT(unittest.TestCase):
def test_tell(self):
key = jax.random.PRNGKey(0)
boa = wazy.BOAlgorithm()
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
boa.tell(key, "CCC", 1)
boa.tell(key, "GG", 0)

def test_predict(self):
key = jax.random.PRNGKey(0)
boa = wazy.BOAlgorithm()
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
boa.tell(key, "CCC", 1)
boa.tell(key, "GG", 0)
boa.predict(key, "FFG")

def test_ask(self):
key = jax.random.PRNGKey(0)
boa = wazy.BOAlgorithm()
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
boa.tell(key, "CCC", 1)
boa.tell(key, "GG", 0)
x, _ = boa.ask(key)
assert len(x) == 2
x, _ = boa.ask(key, length=5)
assert len(x) == 5
x, _ = boa.ask(key, "max")
x, v = boa.ask(key, return_seqs=4)
assert len(x) == 4
assert len(v) == 4

def test_ask_nounirep(self):
key = jax.random.PRNGKey(0)
c = wazy.EnsembleBlockConfig(pretrained=False)
boa = wazy.BOAlgorithm(model_config=c)
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10), model_config=c)
boa.tell(key, "CCC", 1)
boa.tell(key, "EEE", 0)
x, _ = boa.ask(key)

def batch_ask(self):
key = jax.random.PRNGKey(0)
boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
boa.tell(key, "CCC", 1)
boa.tell(key, "EEE", 0)
x, _ = boa.batch_ask(key, N=2, lengths=[3, 2], return_seqs=4)
assert len(x) == 2 * 4
# make sure no dups
assert len(set(x)) == len(x)
55 changes: 46 additions & 9 deletions wazy/asktell.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def predict(self, key, seq):
x = self._get_reps(seq)
return self.model.infer_t.apply(self.params, key, x, training=False)

def ask(self, key, aq_fxn="ucb", length=None):
def ask(self, key, aq_fxn="ucb", length=None, return_seqs=1):
if not self._ready:
raise Exception("Must call tell once before ask")
if length is None:
Expand Down Expand Up @@ -108,14 +108,51 @@ def ask(self, key, aq_fxn="ucb", length=None):
# find best result, not already measured
seq = None
min_idxs = jnp.argsort(jnp.squeeze(bo_loss[-1]))
i = 0
while seq is None or seq in self.seqs:
top_idx = min_idxs[i]
best_v = batched_v[0][top_idx]
# sample max across logits
seq = "".join(decode_seq(best_v))
i += 1
return seq, -bo_loss[-1][top_idx]
out_seq = ["".join(decode_seq(batched_v[0][i])) for i in min_idxs]
out_loss = [bo_loss[-1][i] for i in min_idxs if out_seq[i] not in self.seqs]
out_seq = [o for o in out_seq if o not in self.seqs]
if return_seqs == 1:
return out_seq[0], out_loss[0]
return out_seq[:return_seqs], out_loss[:return_seqs]

def batch_ask(self, key, N, aq_fxn="ucb", lengths=None, return_seqs=1):
"""Batch asking iteratively asking and telling min value
:param key: :class:`jax.random.PRNGKey` for PRNG
:param N: number of rounds of BO/training
:param aq_fxn: acquisition function "ucb", "ei", "max"
:param lengths: list of lengths of sequences to ask for
:param return_seqs: number of sequences to return per round
:return: list of sequences, list of losses. Number returned is N*return_seqs.
May be less than N*return_seqs if duplicates are proposed.
"""
if lengths is None:
lengths = [None] * N
if len(lengths) != N:
raise Exception("Number of lengths must be same length as N")
split = len(self.reps)
out_s, out_v = [], []
count = 0
for i in range(N):
s, v = self.ask(
key, aq_fxn, lengths[i], return_seqs=self.aconfig.bo_batch_size
)
# make sure to not propose same one which we've seen before
v = [vi for vi, si in zip(v, s) if si not in out_s]
s = [si for si in s if si not in out_s]
# make sure not to propose same one twice
keep = [True for ni, si in enumerate(s) if si not in s[:ni]]
s = [si for si, ki in zip(s, keep) if ki]
v = [vi for vi, ki in zip(v, keep) if ki]
out_s.extend(s[:return_seqs])
out_v.extend(v[:return_seqs])
count += len(s)
for j in range(len(s)):
self.tell(None, s[j], min(self.labels))
# pop off the sequences we've added
self.seqs = self.seqs[:split]
self.labels = self.labels[:split]
self.reps = self.reps[:split]
return out_s, out_v

def _init(self, seq, label, key):
self._ready = True
Expand Down
2 changes: 1 addition & 1 deletion wazy/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.1"
__version__ = "0.9.0"

0 comments on commit 9fc7f94

Please sign in to comment.