Commit f66e682

DOCS: add links to noSM and SM 60M ckpts

anmorgunov committed May 29, 2024 (1 parent: 84b40f7)

Showing 5 changed files with 66 additions and 59 deletions.
2 changes: 2 additions & 0 deletions DirectMultiStep/helpers.py
```diff
@@ -105,6 +105,8 @@ def parse_epoch_step(filename: str):
     checkpoints.sort(key=lambda ckpt: parse_epoch_step(ckpt.name), reverse=True)
     return checkpoints[0] if checkpoints else None
 
+
+
 if __name__ == "__main__":
     train_path = Path(__file__).resolve().parent / "Data" / "Training"
     run_name = "moe_3x2_3x3_002_local"
```
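For context, the sort above ranks checkpoint files by their (epoch, step) pair. A minimal sketch of a compatible `parse_epoch_step`, assuming Lightning's default checkpoint naming (e.g. `epoch=3-step=30299.ckpt`); the regex and the `(-1, -1)` fallback are this sketch's assumptions, not the repository's actual code:

```python
import re
from pathlib import Path
from typing import Optional, Tuple

def parse_epoch_step(filename: str) -> Tuple[int, int]:
    # Pull the integers out of names like "epoch=3-step=30299.ckpt";
    # unparseable names sort last via the (-1, -1) fallback (an assumption).
    match = re.search(r"epoch=(\d+)-step=(\d+)", filename)
    if match is None:
        return (-1, -1)
    return int(match.group(1)), int(match.group(2))

def find_latest_checkpoint(ckpt_dir: Path) -> Optional[Path]:
    # Mirrors the sorting shown in the diff: newest (epoch, step) first.
    checkpoints = list(ckpt_dir.glob("*.ckpt"))
    checkpoints.sort(key=lambda ckpt: parse_epoch_step(ckpt.name), reverse=True)
    return checkpoints[0] if checkpoints else None
```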
19 changes: 4 additions & 15 deletions README.md
````diff
@@ -4,17 +4,6 @@
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![codecov](https://codecov.io/gh/batistagroup/DirectMultiStep/graph/badge.svg?token=2G1x86tsjc)](https://codecov.io/gh/batistagroup/DirectMultiStep)
 
-Code coverage with tests:
-
-```bash
-Name                  Stmts   Miss  Cover
------------------------------------------
-Data/Dataset.py          157    114    27%
-Utils/PreProcess.py       60      8    87%
------------------------------------------
-TOTAL                    217    122    44%
-```
-
 ## Overview
 
 The preprint for this work is posted on [arXiv](https://arxiv.org/abs/2405.13983).
@@ -45,15 +34,15 @@ Finally, we provide [assess_single.py](/assess_single.py) which allows to run ou
 
 To use the tutorials, simply move/copy them to the root directory. This is necessary because the notebooks use relative imports.
 
-- [Tutorials/Basic_Usage.ipynb](/Tutorials/Basic_Usage.ipynb) walks you through how to input your compounds, steps, and starting materials. Visualization of routes in PDF is shown.
+- [Tutorials/Basic_Usage.ipynb](/Tutorials/Basic_Usage.ipynb) walks you through how to input your compounds, steps, and starting materials. Visualization of routes in PDF is shown.
 - [Tutorials/Route_Separation.ipynb](/Tutorials/Route_Separation.ipynb) reproduces the route separation results from the paper.
 - [Tutorials/Pharma_Compounds.ipynb](/Tutorials/Pharma_Compounds.ipynb) reproduces the three FDA-approved drug results from the paper.
 
 ## Licenses
 
-All code is licensed under MIT License.
+All code is licensed under MIT License. The content of the [pre-print on arXiv](https://arxiv.org/abs/2405.13983) is licensed under CC-BY 4.0.
 
-## TODO:
+## TODO
 
-- Bring codecov to 80+.
+- Bring codecov to 80+.
 - Revise [Models/TensorGen.py](/DirectMultiStep/Models/TensorGen.py) so that it can work with batch size greater than 1.
````
33 changes: 32 additions & 1 deletion download_ckpts.sh
```diff
@@ -1,4 +1,35 @@
 #!/bin/bash
 
 mkdir -p Data/Checkpoints
-curl -o Data/Checkpoints/van_6x3_6x3_final.ckpt https://files.batistalab.com/DirectMultiStep/ckpts/van_6x3_6x3_final.ckpt
+curl -o Data/Checkpoints/van_6x3_6x3_final.ckpt https://files.batistalab.com/DirectMultiStep/ckpts/van_6x3_6x3_final.ckpt
+
+read -p "Do you want to download (with SM, 10M) model ckpt? (38 MB) [y/N]: " choice
+case "$choice" in
+  y|Y )
+    curl -o Data/Checkpoints/sm_6x3_6x3_final.ckpt https://files.batistalab.com/DirectMultiStep/ckpts/sm_6x3_6x3_final.ckpt
+    ;;
+  * )
+    echo "Skipping (with SM, 10M) ckpt."
+    ;;
+esac
+
+read -p "Do you want to download (with SM, 60M) model ckpt? (228 MB) [y/N]: " choice
+case "$choice" in
+  y|Y )
+    curl -o Data/Checkpoints/sm_8x4_8x4_final.ckpt https://files.batistalab.com/DirectMultiStep/ckpts/sm_8x4_8x4_final.ckpt
+    ;;
+  * )
+    echo "Skipping (with SM, 60M) ckpt."
+    ;;
+esac
+
+read -p "Do you want to download (without SM, 60M) model ckpt? (228 MB) [y/N]: " choice
+case "$choice" in
+  y|Y )
+    curl -o Data/Checkpoints/nosm_8x4_8x4_final.ckpt https://files.batistalab.com/DirectMultiStep/ckpts/nosm_8x4_8x4_final.ckpt
+    ;;
+  * )
+    echo "Skipping (without SM, 60M) ckpt."
+    ;;
+esac
```
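For machines without curl, a rough Python equivalent of the script above; the URLs and checkpoint names are taken verbatim from the diff, while the prompt wording and the urllib-based download loop are this sketch's own construction:

```python
import urllib.request
from pathlib import Path

BASE = "https://files.batistalab.com/DirectMultiStep/ckpts"
OPTIONAL_CKPTS = {
    "sm_6x3_6x3_final.ckpt": "(with SM, 10M) ckpt (38 MB)",
    "sm_8x4_8x4_final.ckpt": "(with SM, 60M) ckpt (228 MB)",
    "nosm_8x4_8x4_final.ckpt": "(without SM, 60M) ckpt (228 MB)",
}

ckpt_dir = Path("Data/Checkpoints")
ckpt_dir.mkdir(parents=True, exist_ok=True)

# The van_6x3_6x3 checkpoint is always fetched, as in the shell script.
urllib.request.urlretrieve(f"{BASE}/van_6x3_6x3_final.ckpt", ckpt_dir / "van_6x3_6x3_final.ckpt")

for name, desc in OPTIONAL_CKPTS.items():
    if input(f"Download {desc}? [y/N]: ").strip().lower() == "y":
        urllib.request.urlretrieve(f"{BASE}/{name}", ckpt_dir / name)
    else:
        print(f"Skipping {desc}.")
```

Either way, a completed download can be sanity-checked by loading it: Lightning checkpoints are ordinary torch.save archives, so `torch.load(path, map_location="cpu")` should return a dict containing a `state_dict`.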
35 changes: 13 additions & 22 deletions DirectMultiStep/train_nosm.py → train_nosm.py
```diff
@@ -23,50 +23,41 @@
 import torch
 import lightning as L
 from pathlib import Path
-from Models.Configure import prepare_model, determine_device, VanillaTransformerConfig
-from Models.Training import PLTraining
+from DirectMultiStep.Models.Configure import prepare_model, determine_device, VanillaTransformerConfig
+from DirectMultiStep.Models.Training import PLTraining
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.callbacks import RichModelSummary
 import DirectMultiStep.helpers as helpers
 
 data_path = Path(__file__).resolve().parent / "Data" / "Processed"
 train_path = Path(__file__).resolve().parent / "Data" / "Training"
-run_name = "van_6x3_6x3_020"
+run_name = "nosm_run_name"
 batch_size = 32
 lr = 3e-4
 steps_per_epoch = 30299
 max_epochs = 4
 L.seed_everything(42)
-n_devices = 1
+n_devices = 4
 torch.set_float32_matmul_precision("high")
 dl_kwargs = dict(num_workers=0, pin_memory=True)
 
-van_enc_conf = VanillaTransformerConfig(
-    input_dim=53,
-    output_dim=53,
-    input_max_length=145,
-    output_max_length=1074 + 1,  # 1074 is max length
-    pad_index=52,
-    n_layers=12,
-    ff_mult=4,
-    attn_bias=False,
-    ff_activation="gelu",
-    hid_dim=512,
-)
-van_dec_conf = VanillaTransformerConfig(
+model_10m = dict(n_layers=6, ff_mult=3, hid_dim=256)
+model_60m = dict(n_layers=8, ff_mult=4, hid_dim=512)
+
+model_config = VanillaTransformerConfig(
     input_dim=53,
     output_dim=53,
     input_max_length=145,
-    output_max_length=1074 + 1,  # 1074 is max length
+    output_max_length=1074 + 1,
     pad_index=52,
-    n_layers=12,
-    ff_mult=4,
     attn_bias=False,
     ff_activation="gelu",
-    hid_dim=512,
+    **model_10m,
+    # **model_60m,
 )
 
-model = prepare_model(enc_config=van_enc_conf, dec_config=van_dec_conf)
+# enc and dec configs may be different
+model = prepare_model(enc_config=model_config, dec_config=model_config)
 if __name__ == "__main__":
     # Training hyperparameters
     mask_idx, pad_idx = 51, 52
```
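The `model_10m` and `model_60m` presets introduced above pin down the two model sizes referenced in the commit message. A back-of-the-envelope parameter count, assuming standard encoder-decoder transformer blocks (4·h² of self-attention weights per layer, an extra 4·h² of cross-attention per decoder layer, and 2·ff_mult·h² of feed-forward weights) and ignoring biases, layer norms, and positional embeddings; this is an estimate for orientation, not the repository's own accounting:

```python
def approx_params(n_layers: int, ff_mult: int, hid_dim: int, vocab: int = 53) -> float:
    """Rough encoder-decoder parameter count in millions."""
    attn = 4 * hid_dim**2             # Q, K, V, and output projections
    ff = 2 * ff_mult * hid_dim**2     # two feed-forward matrices
    enc = n_layers * (attn + ff)      # encoder: self-attention + FF
    dec = n_layers * (2 * attn + ff)  # decoder: self- + cross-attention + FF
    emb = 2 * vocab * hid_dim         # source and target token embeddings
    return (enc + dec + emb) / 1e6

print(approx_params(n_layers=6, ff_mult=3, hid_dim=256))  # ~9.4  -> the "10M" preset
print(approx_params(n_layers=8, ff_mult=4, hid_dim=512))  # ~58.8 -> the "60M" preset
```

Both land close to their nominal sizes, which supports reading the `6x3` and `8x4` in the checkpoint names as n_layers x ff_mult for each stack.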
36 changes: 15 additions & 21 deletions DirectMultiStep/train_wsm.py → train_wsm.py
```diff
@@ -23,15 +23,15 @@
 import torch
 import lightning as L
 from pathlib import Path
-from Models.Configure import prepare_model, determine_device, VanillaTransformerConfig
-from Models.Training import PLTraining
+from DirectMultiStep.Models.Configure import prepare_model, determine_device, VanillaTransformerConfig
+from DirectMultiStep.Models.Training import PLTraining
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.callbacks import RichModelSummary
 import DirectMultiStep.helpers as helpers
 
 data_path = Path(__file__).resolve().parent / "Data" / "Processed"
 train_path = Path(__file__).resolve().parent / "Data" / "Training"
-run_name = "van_6x3_6x3_010"
+run_name = "sm_run_name"
 batch_size = 32
 lr = 3e-4
 steps_per_epoch = 30299
@@ -41,32 +41,26 @@
 torch.set_float32_matmul_precision("high")
 dl_kwargs = dict(num_workers=120, pin_memory=True)
 
-van_enc_conf = VanillaTransformerConfig(
-    input_dim=53,
-    output_dim=53,
-    input_max_length=145 + 135,
-    output_max_length=1074 + 1,  # 1074 is max length
-    pad_index=52,
-    n_layers=6,
-    ff_mult=3,
-    attn_bias=False,
-    ff_activation="gelu",
-    hid_dim=256,
-)
-van_dec_conf = VanillaTransformerConfig(
+
+sm_config = dict()
+
+model_10m = dict(n_layers=6, ff_mult=3, hid_dim=256)
+model_60m = dict(n_layers=8, ff_mult=4, hid_dim=512)
+
+model_config = VanillaTransformerConfig(
     input_dim=53,
     output_dim=53,
     input_max_length=145 + 135,
-    output_max_length=1074 + 1,  # 1074 is max length
+    output_max_length=1074 + 1,
     pad_index=52,
-    n_layers=6,
-    ff_mult=3,
     attn_bias=False,
     ff_activation="gelu",
-    hid_dim=256,
+    **model_10m,
+    # **model_60m,
 )
 
-model = prepare_model(enc_config=van_enc_conf, dec_config=van_dec_conf)
+# enc and dec configs may be different
+model = prepare_model(enc_config=model_config, dec_config=model_config)
 if __name__ == "__main__":
     # Training hyperparameters
     mask_idx, pad_idx = 51, 52
```
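train_wsm.py differs from train_nosm.py chiefly in its encoder budget: `input_max_length=145 + 135` reads as the target's encoding (up to 145 tokens) plus the starting-material SMILES (up to 135 tokens) presented to the encoder together. A hypothetical illustration of that split; the function and its budget arguments are illustrative, not the repository's actual preprocessing:

```python
def build_encoder_input(target_tokens: "list[int]", sm_tokens: "list[int]",
                        max_target: int = 145, max_sm: int = 135) -> "list[int]":
    # Hypothetical sketch: enforce a per-part budget so the concatenated
    # sequence never exceeds input_max_length = 145 + 135 = 280.
    if len(target_tokens) > max_target or len(sm_tokens) > max_sm:
        raise ValueError("sequence exceeds its length budget")
    return target_tokens + sm_tokens
```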
