diff --git a/README.md b/README.md
index cd2871f..c42b6d0 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ pip install wazy@git+https://github.com/ur-whitelab/wazy
 ## Quickstart
 
 You can use an ask/tell style interface to design a peptide.
-We can tell a few examples of sequences we know and their scalar labels. Let's try a simple example where the label is the number of alanines. We'll start by importing and building a `BOAlgorithm` class. *In this example, I re-use the same key for simplicity.*
+We can tell a few examples of sequences we know and their scalar labels. Let's try a simple example where the label is the number of alanines. You'll also want your labels to vary from about -5 to 5. We'll start by importing and building a `BOAlgorithm` class. *In this example, I re-use the same key for simplicity.*
 
 ```py
 import wazy
diff --git a/tests/test_wazy.py b/tests/test_wazy.py
index 62bcb9c..9bae941 100644
--- a/tests/test_wazy.py
+++ b/tests/test_wazy.py
@@ -234,3 +234,11 @@ def test_ask(self):
         x, _ = boa.ask(key, length=5)
         assert len(x) == 5
         x, _ = boa.ask(key, "max")
+
+    def test_ask_nounirep(self):
+        key = jax.random.PRNGKey(0)
+        c = wazy.EnsembleBlockConfig(pretrained=False)
+        boa = wazy.BOAlgorithm(model_config=c)
+        boa.tell(key, "CCC", 1)
+        boa.tell(key, "EEE", 0)
+        x, _ = boa.ask(key)
diff --git a/wazy/asktell.py b/wazy/asktell.py
index bd1b61f..cf4c73a 100644
--- a/wazy/asktell.py
+++ b/wazy/asktell.py
@@ -28,12 +28,18 @@ def __init__(self, model_config=None, alg_config=None) -> None:
         self._ready = False
         self._trained = 0
 
+    def _get_reps(self, seq):
+        if self.mconfig.pretrained:
+            return get_reps([seq])[0][0]
+        else:
+            return encode_seq(seq).flatten()
+
     def tell(self, key, seq, label):
         if not self._ready:
             key, _ = jax.random.split(key)
             self._init(seq, label, key)
         self.seqs.append(seq)
-        self.reps.append(get_reps([seq])[0][0])
+        self.reps.append(self._get_reps(seq))
         self.labels.append(label)
 
     def _maybe_train(self, key):
@@ -54,7 +60,7 @@ def predict(self, key, seq):
             raise Exception("Must call tell once before predict")
         self._maybe_train(key)
         key = jax.random.split(key)[0]
-        x = get_reps([seq])[0][0]
+        x = self._get_reps(seq)
         return self.model.infer_t.apply(self.params, key, x, training=False)
 
     def ask(self, key, aq_fxn="ucb", length=None):
diff --git a/wazy/e2e.py b/wazy/e2e.py
index 74093f8..6ee77da 100644
--- a/wazy/e2e.py
+++ b/wazy/e2e.py
@@ -41,10 +41,11 @@ def model_uncertainty_eval(x, training=False):
     def seq_forward(x, training=True):
         # params is trained mlp params
         s = SeqpropBlock()(x)
-        us = seq2useq(s)
-        # TODO: What does the flatten line do???
-        u = differentiable_jax_unirep(us)
-        # u = s.flatten()
+        if config.pretrained:
+            us = seq2useq(s)
+            u = differentiable_jax_unirep(us)
+        else:
+            u = s.flatten()
         mean, var, epi_var = model_forward(u, training=training)
         # We only use epistemic uncertainty, since this is used in BO
         return mean, epi_var
diff --git a/wazy/mlp.py b/wazy/mlp.py
index e3bf26c..9835c09 100644
--- a/wazy/mlp.py
+++ b/wazy/mlp.py
@@ -23,6 +23,7 @@ class EnsembleBlockConfig:
     )
     model_number: int = 5
     dropout: float = 0.2
+    pretrained: bool = True
 
 
 @dataclass
diff --git a/wazy/version.py b/wazy/version.py
index a5f830a..777f190 100644
--- a/wazy/version.py
+++ b/wazy/version.py
@@ -1 +1 @@
-__version__ = "0.7.1"
+__version__ = "0.8.0"
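
For context, here is a minimal usage sketch of the new `pretrained=False` mode introduced by this diff. It mirrors the added `test_ask_nounirep` test; the `EnsembleBlockConfig`, `BOAlgorithm`, `tell`, and `ask` calls are taken directly from the diff above, so treat it as an illustration rather than documented API.

```py
import jax
import wazy

key = jax.random.PRNGKey(0)

# pretrained=False (new in 0.8.0) skips the UniRep representation and
# feeds the raw encoded sequence to the model (see _get_reps in asktell.py)
config = wazy.EnsembleBlockConfig(pretrained=False)
boa = wazy.BOAlgorithm(model_config=config)

# tell a few labeled sequences; per the README, labels should vary
# from about -5 to 5
boa.tell(key, "CCC", 1)
boa.tell(key, "EEE", 0)

# ask for a new candidate sequence
x, _ = boa.ask(key)
```

Note that the branch added to `seq_forward` in `e2e.py` keeps the optimization loop identical in both modes; only the representation fed to the ensemble model changes (the `differentiable_jax_unirep` output versus the flattened output of `encode_seq`).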