Made it possible to turn off pretraining
whitead committed Aug 7, 2022
1 parent 479808c commit 83c0eb9
Showing 6 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@ pip install wazy@git+https://github.com/ur-whitelab/wazy
 ## Quickstart
 You can use an ask/tell style interface to design a peptide.
 
-We can tell a few examples of sequences we know and their scalar labels. Let's try a simple example where the label is the number of alanines. We'll start by importing and building a `BOAlgorithm` class. *In this example, I re-use the same key for simplicity.*
+We can tell a few examples of sequences we know and their scalar labels. Let's try a simple example where the label is the number of alanines. You'll also want your labels to vary from about -5 to 5. We'll start by importing and building a `BOAlgorithm` class. *In this example, I re-use the same key for simplicity.*
 
 ```py
 import wazy
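The quickstart code block is truncated by the diff; as a hedged sketch of the ask/tell loop the README describes, using only calls visible in this commit (the specific sequences and labels are illustrative):

```py
import jax
import wazy

key = jax.random.PRNGKey(0)  # re-used throughout, as the README notes
boa = wazy.BOAlgorithm()

# Tell a few sequences with scalar labels (here: count of alanines)
boa.tell(key, "GGGG", 0)
boa.tell(key, "GAHK", 1)
boa.tell(key, "DAAE", 2)

# Ask for a new candidate sequence
new_seq, _ = boa.ask(key)
```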
8 changes: 8 additions & 0 deletions tests/test_wazy.py
@@ -234,3 +234,11 @@ def test_ask(self):
         x, _ = boa.ask(key, length=5)
         assert len(x) == 5
         x, _ = boa.ask(key, "max")
+
+    def test_ask_nounirep(self):
+        key = jax.random.PRNGKey(0)
+        c = wazy.EnsembleBlockConfig(pretrained=False)
+        boa = wazy.BOAlgorithm(model_config=c)
+        boa.tell(key, "CCC", 1)
+        boa.tell(key, "EEE", 0)
+        x, _ = boa.ask(key)
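This test drives the `pretrained=False` path end-to-end: `tell` and `ask` operate on raw sequence encodings rather than UniRep embeddings (see `_get_reps` in `wazy/asktell.py` below).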
10 changes: 8 additions & 2 deletions wazy/asktell.py
@@ -28,12 +28,18 @@ def __init__(self, model_config=None, alg_config=None) -> None:
         self._ready = False
         self._trained = 0
 
+    def _get_reps(self, seq):
+        if self.mconfig.pretrained:
+            return get_reps([seq])[0][0]
+        else:
+            return encode_seq(seq).flatten()
+
     def tell(self, key, seq, label):
         if not self._ready:
             key, _ = jax.random.split(key)
             self._init(seq, label, key)
         self.seqs.append(seq)
-        self.reps.append(get_reps([seq])[0][0])
+        self.reps.append(self._get_reps(seq))
         self.labels.append(label)
 
     def _maybe_train(self, key):
@@ -54,7 +60,7 @@ def predict(self, key, seq):
             raise Exception("Must call tell once before predict")
         self._maybe_train(key)
         key = jax.random.split(key)[0]
-        x = get_reps([seq])[0][0]
+        x = self._get_reps(seq)
         return self.model.infer_t.apply(self.params, key, x, training=False)
 
     def ask(self, key, aq_fxn="ucb", length=None):
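Based only on the signatures visible in this hunk, a hedged sketch of how `predict` and the `ask` variants are called (the `"max"` acquisition string appears in the existing tests; everything else here is an assumption):

```py
# a sketch, assuming the signatures shown in this diff
pred = boa.predict(key, "CCC")    # ensemble inference on one sequence

x, v = boa.ask(key)               # default acquisition function, aq_fxn="ucb"
x, v = boa.ask(key, "max")        # alternate acquisition, as used in test_ask
x, v = boa.ask(key, length=5)     # constrain the designed sequence length
```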
9 changes: 5 additions & 4 deletions wazy/e2e.py
@@ -41,10 +41,11 @@ def model_uncertainty_eval(x, training=False):

 def seq_forward(x, training=True):  # params is trained mlp params
     s = SeqpropBlock()(x)
-    us = seq2useq(s)
-    # TODO: What does the flatten line do???
-    u = differentiable_jax_unirep(us)
-    # u = s.flatten()
+    if config.pretrained:
+        us = seq2useq(s)
+        u = differentiable_jax_unirep(us)
+    else:
+        u = s.flatten()
     mean, var, epi_var = model_forward(u, training=training)
     # We only use epistemic uncertainty, since this is used in BO
     return mean, epi_var
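Design note: with `config.pretrained` set, the sequence sampled by `SeqpropBlock` is converted via `seq2useq` and embedded with `differentiable_jax_unirep` before reaching the ensemble; with it unset, the sequence representation is simply flattened and passed straight to `model_forward`, which also resolves the old TODO about the flatten line.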
1 change: 1 addition & 0 deletions wazy/mlp.py
@@ -23,6 +23,7 @@ class EnsembleBlockConfig:
     )
     model_number: int = 5
     dropout: float = 0.2
+    pretrained: bool = True
 
 
 @dataclass
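Since `pretrained` defaults to `True`, code that constructs `EnsembleBlockConfig` without arguments keeps the previous UniRep-based behavior; only an explicit `pretrained=False` opts out.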
2 changes: 1 addition & 1 deletion wazy/version.py
@@ -1 +1 @@
__version__ = "0.7.1"
__version__ = "0.8.0"
