diff --git a/README.md b/README.md
index b0d4f7c..c7c3f62 100644
--- a/README.md
+++ b/README.md
@@ -76,3 +76,19 @@ boa.ask(key, "max", length=5)
 # Output
 ('DAAAA', 3.8262398)
 ```
+
+### Batching
+
+You can increase the number of returned sequences by using `batch_ask`, which uses an ad-hoc regret minimization strategy to spread out the proposed sequences:
+
+```py
+boa.batch_ask(key, N=3)
+# returns 3 seqs
+```
+
+You can also use `return_seqs` as a multiplier on the number of batched sequences (no overhead), but they may be similar to each other:
+
+```py
+boa.batch_ask(key, N=3, return_seqs=10)
+# returns up to 30 seqs (fewer if duplicates are proposed)
+```
diff --git a/tests/test_wazy.py b/tests/test_wazy.py
index 9bae941..9cfb619 100644
--- a/tests/test_wazy.py
+++ b/tests/test_wazy.py
@@ -213,20 +213,20 @@ def x0_gen(key, batch_size, L):
 class TestAT(unittest.TestCase):
     def test_tell(self):
         key = jax.random.PRNGKey(0)
-        boa = wazy.BOAlgorithm()
+        boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
         boa.tell(key, "CCC", 1)
         boa.tell(key, "GG", 0)
 
     def test_predict(self):
         key = jax.random.PRNGKey(0)
-        boa = wazy.BOAlgorithm()
+        boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
         boa.tell(key, "CCC", 1)
         boa.tell(key, "GG", 0)
         boa.predict(key, "FFG")
 
     def test_ask(self):
         key = jax.random.PRNGKey(0)
-        boa = wazy.BOAlgorithm()
+        boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
         boa.tell(key, "CCC", 1)
         boa.tell(key, "GG", 0)
         x, _ = boa.ask(key)
@@ -234,11 +234,24 @@ def test_ask(self):
         x, _ = boa.ask(key, length=5)
         assert len(x) == 5
         x, _ = boa.ask(key, "max")
+        x, v = boa.ask(key, return_seqs=4)
+        assert len(x) == 4
+        assert len(v) == 4
 
     def test_ask_nounirep(self):
         key = jax.random.PRNGKey(0)
         c = wazy.EnsembleBlockConfig(pretrained=False)
-        boa = wazy.BOAlgorithm(model_config=c)
+        boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10), model_config=c)
         boa.tell(key, "CCC", 1)
         boa.tell(key, "EEE", 0)
         x, _ = boa.ask(key)
+
+    def test_batch_ask(self):
+        key = jax.random.PRNGKey(0)
+        boa = wazy.BOAlgorithm(alg_config=wazy.AlgConfig(bo_epochs=10))
+        boa.tell(key, "CCC", 1)
+        boa.tell(key, "EEE", 0)
+        x, _ = boa.batch_ask(key, N=2, lengths=[3, 2], return_seqs=4)
+        assert len(x) == 2 * 4
+        # make sure no dups
+        assert len(set(x)) == len(x)
diff --git a/wazy/asktell.py b/wazy/asktell.py
index cf4c73a..9b4ef9b 100644
--- a/wazy/asktell.py
+++ b/wazy/asktell.py
@@ -63,7 +63,7 @@ def predict(self, key, seq):
         x = self._get_reps(seq)
         return self.model.infer_t.apply(self.params, key, x, training=False)
 
-    def ask(self, key, aq_fxn="ucb", length=None):
+    def ask(self, key, aq_fxn="ucb", length=None, return_seqs=1):
         if not self._ready:
             raise Exception("Must call tell once before ask")
         if length is None:
@@ -108,14 +108,63 @@ def ask(self, key, aq_fxn="ucb", length=None):
         # find best result, not already measured
         seq = None
         min_idxs = jnp.argsort(jnp.squeeze(bo_loss[-1]))
-        i = 0
-        while seq is None or seq in self.seqs:
-            top_idx = min_idxs[i]
-            best_v = batched_v[0][top_idx]
-            # sample max across logits
-            seq = "".join(decode_seq(best_v))
-            i += 1
-        return seq, -bo_loss[-1][top_idx]
+        # candidates decoded in order of increasing loss (best first)
+        out_seq = ["".join(decode_seq(batched_v[0][i])) for i in min_idxs]
+        # out_seq is in sorted order, so pair it with min_idxs by position when
+        # dropping already-measured sequences; this keeps both lists aligned
+        # NOTE(review): the old code returned -bo_loss; confirm the sign change
+        out_loss = [
+            bo_loss[-1][i] for i, o in zip(min_idxs, out_seq) if o not in self.seqs
+        ]
+        out_seq = [o for o in out_seq if o not in self.seqs]
+        if return_seqs == 1:
+            return out_seq[0], out_loss[0]
+        return out_seq[:return_seqs], out_loss[:return_seqs]
+
+    def batch_ask(self, key, N, aq_fxn="ucb", lengths=None, return_seqs=1):
+        """Propose a batch of sequences over N rounds of asking and telling.
+
+        Each round calls :meth:`ask`, drops duplicate proposals, and then tells
+        the proposals back labelled with ``min(self.labels)`` to spread out
+        later proposals. The temporary tells are removed from ``seqs``,
+        ``labels`` and ``reps`` before returning.
+
+        :param key: :class:`jax.random.PRNGKey` for PRNG
+        :param N: number of rounds of BO/training
+        :param aq_fxn: acquisition function "ucb", "ei", "max"
+        :param lengths: list of lengths of sequences to ask for
+        :param return_seqs: number of sequences to return per round
+        :return: list of sequences, list of losses. Number returned is N*return_seqs.
+            May be less than N*return_seqs if duplicates are proposed.
+        :raises Exception: if ``lengths`` is given and its length is not N
+        """
+        if lengths is None:
+            lengths = [None] * N
+        if len(lengths) != N:
+            raise Exception("Number of lengths must be same length as N")
+        split = len(self.reps)
+        out_s, out_v = [], []
+        for length in lengths:
+            s, v = self.ask(
+                key, aq_fxn, length, return_seqs=self.aconfig.bo_batch_size
+            )
+            # make sure to not propose same one which we've seen before
+            v = [vi for vi, si in zip(v, s) if si not in out_s]
+            s = [si for si in s if si not in out_s]
+            # make sure not to propose same one twice: keep[ni] is True only for
+            # the first occurrence of each sequence, one entry per element of s
+            keep = [si not in s[:ni] for ni, si in enumerate(s)]
+            s = [si for si, ki in zip(s, keep) if ki]
+            v = [vi for vi, ki in zip(v, keep) if ki]
+            out_s.extend(s[:return_seqs])
+            out_v.extend(v[:return_seqs])
+            for si in s:
+                self.tell(None, si, min(self.labels))
+        # pop off the sequences we've added
+        self.seqs = self.seqs[:split]
+        self.labels = self.labels[:split]
+        self.reps = self.reps[:split]
+        return out_s, out_v
 
     def _init(self, seq, label, key):
         self._ready = True
diff --git a/wazy/version.py b/wazy/version.py
index 8088f75..3e2f46a 100644
--- a/wazy/version.py
+++ b/wazy/version.py
@@ -1 +1 @@
-__version__ = "0.8.1"
+__version__ = "0.9.0"