Commit c5f8f6d

Compliance with strict mypy, better style in visualize module, pre-commits, code quality workflow (#3)

* STYLE: linting with ruff, fixing mypy issues

* GIT: ruff workflow

* STYLE: type annotations for Visualize module, and mypy fixes

* GIT: modify code qual workflow

* GIT: modify workflow pt2

* GIT: modify workflow pt3

* GIT: pre-commits

* GIT: update qual workflow

* GIT: local black was outdated

* GIT: add type stubs for deps

* GIT: type ignore import of lightning

* GIT: lmao at rdkit having null bytes

* GIT: add reqs and update workflow

* GIT: typo in dep version

* DOCS: add ruff badge to readme
anmorgunov authored May 31, 2024
1 parent f66e682 commit c5f8f6d
Showing 24 changed files with 897 additions and 832 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/main.yml
@@ -3,10 +3,11 @@ name: DirectMultiStep CI
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
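Judging from the counts above (2 additions, 1 deletion), the net effect appears to be a pull_request trigger on main, so the CI suite now runs on pull requests as well as direct pushes.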
35 changes: 35 additions & 0 deletions .github/workflows/quality.yml
@@ -0,0 +1,35 @@
name: Code Quality

on: [push, pull_request]

jobs:
  qualitycheck:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy==1.24.3 torch==2.3.0 rdkit==2023.09.3 pyyaml==6.0
          pip install pytest pytest-cov lightning==2.2.5 pycairo==1.26.0 cairosvg==2.7.1
          pip install mypy black isort ruff types-requests types-tqdm types-PyYAML

      - name: Run ruff
        run: ruff check

      - name: Run black
        run: black --check .

      - name: Run isort
        run: isort --check --profile black .

      - name: Run mypy
        run: mypy --strict . --exclude=tests

      - name: Run tests
        run: pytest -v
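The same gate can be reproduced locally before pushing. A minimal sketch of a helper script (hypothetical, not part of this commit) that mirrors the workflow's check steps, assuming ruff, black, isort, mypy, and pytest are installed in the active environment:

# local_checks.py -- hypothetical mirror of the CI quality gate above.
import subprocess
import sys

CHECKS: list[list[str]] = [
    ["ruff", "check"],
    ["black", "--check", "."],
    ["isort", "--check", "--profile", "black", "."],
    ["mypy", "--strict", ".", "--exclude=tests"],
    ["pytest", "-v"],
]


def main() -> int:
    for cmd in CHECKS:
        print(f"$ {' '.join(cmd)}")
        if subprocess.run(cmd).returncode != 0:
            return 1  # stop at the first failing check, like a CI step
    return 0


if __name__ == "__main__":
    sys.exit(main())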
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ Data/PaRoutes/*json
Data/Processed/pre_perms
Data/Processed/*pkl
Data/Training/
Data/Figures/
*pkl

.coverage
25 changes: 25 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,25 @@
default_language_version:
  python: python3.11

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.7
    hooks:
      - id: ruff
        args: [ --fix ]

  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        types: [python]
        args: [--check]

  - repo: local
    hooks:
      - id: isort
        name: isort
        entry: isort
        language: system
        types: [python]
        args: [--check,--profile=black]
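Once this config is checked in, running pre-commit install once per clone registers the hooks, and pre-commit run --all-files applies ruff, black, and isort to the whole tree, catching locally what the Code Quality workflow would flag in CI.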
13 changes: 6 additions & 7 deletions Data/download.py
@@ -1,7 +1,8 @@
-""" Module with script to download public data
+"""Module with script to download public data
 Adapted from https://github.com/MolecularAI/PaRoutes/blob/main/data/download_data.py
 """

 import os
 import sys
 from pathlib import Path
@@ -33,8 +34,8 @@
 ]


-def _download_file(url: str, filename: str) -> None:
-    with requests.get(url, stream=True) as response:
+def _download_file(url: str | Path, filename: str | Path) -> None:
+    with requests.get(str(url), stream=True) as response:
         response.raise_for_status()
         total_size = int(response.headers.get("content-length", 0))
         pbar = tqdm.tqdm(
@@ -53,13 +54,11 @@ def main() -> None:
     path.mkdir(parents=True, exist_ok=True)
     for filespec in FILES_TO_DOWNLOAD:
         try:
-            _download_file(
-                filespec["url"], path / filespec["filename"]
-            )
+            _download_file(filespec["url"], path / filespec["filename"])
         except requests.HTTPError as err:
             print(f"Download failed with message {str(err)}")
             sys.exit(1)


 if __name__ == "__main__":
-    main()
+    main()
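For context, here is how the widened helper reads after this change, as a self-contained sketch. The progress-bar body is reconstructed from the visible context lines plus standard requests/tqdm usage, so treat everything past the pbar line as an assumption rather than the file's exact code:

from pathlib import Path

import requests
import tqdm


def _download_file(url: str | Path, filename: str | Path) -> None:
    # Accept Path objects as well as strings; requests itself wants a str URL.
    with requests.get(str(url), stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        pbar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
        with open(filename, "wb") as fileobj:
            for chunk in response.iter_content(chunk_size=8192):
                fileobj.write(chunk)
                pbar.update(len(chunk))
        pbar.close()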
93 changes: 21 additions & 72 deletions Data/process.py
@@ -22,16 +22,18 @@
 import json
 import pickle
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Union, cast

 from tqdm import tqdm
-from Utils.PreProcess import (
+
+from DirectMultiStep.Utils.PreProcess import (
+    FilteredDict,
     filter_mol_nodes,
-    max_tree_depth,
     find_leaves,
-    FilteredDict,
     generate_permutations,
+    max_tree_depth,
 )
-from pathlib import Path
-from typing import List, Tuple, Dict, Union, Set, Optional

 data_path = Path(__file__).parent / "PaRoutes"
 save_path = Path(__file__).parent / "Processed"
@@ -43,8 +45,7 @@
 class PaRoutesDataset:
-
-    def __init__(self, data_path: Path, filename: str, verbose: bool = True):
+    def __init__(self, data_path: Path, filename: str, verbose: bool = True) -> None:
         self.data_path = data_path
         self.filename = filename
         self.dataset = json.load(open(data_path.joinpath(filename), "r"))
@@ -53,70 +54,19 @@ def __init__(self, data_path: Path, filename: str, verbose: bool = True):

         self.products: List[str] = []
         self.filtered_data: FilteredType = []
-        self.path_strings: List[str] = []
-        self.max_steps: List[int] = []
-        self.SMs: List[List[str]] = []
+        # self.path_strings: List[str] = []
+        # self.max_steps: List[int] = []
+        # self.SMs: List[List[str]] = []

-        self.non_permuted_path_strings: List[str] = []
+        # self.non_permuted_path_strings: List[str] = []

-    def filter_dataset(self):
+    def filter_dataset(self) -> None:
         if self.verbose:
             print("- Filtering all_routes to remove meta data")
         for route in tqdm(self.dataset):
             filtered_node = filter_mol_nodes(route)
             self.filtered_data.append(filtered_node)
-            self.products.append(filtered_node["smiles"])
-
-    def compress_to_string(self):
-        if self.verbose:
-            print(
-                "- Compressing python dictionaries into python strings and generating permutations"
-            )
-
-        for filtered_route in tqdm(self.filtered_data):
-            permuted_path_strings = generate_permutations(filtered_route)
-            # permuted_path_strings = [str(data).replace(" ", "")]
-            self.path_strings.append(permuted_path_strings)
-            self.non_permuted_path_strings.append(str(filtered_route).replace(" ", ""))
-
-    def find_max_depth(self):
-        if self.verbose:
-            print("- Finding the max depth of each route tree")
-        for filtered_route in tqdm(self.filtered_data):
-            self.max_steps.append(max_tree_depth(filtered_route))
-
-    def find_all_leaves(self):
-        if self.verbose:
-            print("- Finding all leaves of each route tree")
-        for filtered_route in tqdm(self.filtered_data):
-            self.SMs.append(find_leaves(filtered_route))
-
-    def preprocess(self):
-        self.filter_dataset()
-        self.compress_to_string()
-        self.find_max_depth()
-        self.find_all_leaves()
-
-    def prepare_final_datasets(
-        self, exclude: Optional[Set[int]] = None
-    ) -> Tuple[Dataset, Dataset]:
-        if exclude is None:
-            exclude = set()
-        dataset:Dataset = []
-        dataset_each_sm:Dataset = []
-        for i in tqdm(range(len(self.products))):
-            if i in exclude:
-                continue
-            entry:DatasetEntry = {
-                "train_ID": i,
-                "product": self.products[i],
-                "path_strings": self.path_strings[i],
-                "max_step": self.max_steps[i],
-            }
-            dataset.append(entry | {"all_SM": self.SMs[i]})
-            for sm in self.SMs[i]:
-                dataset_each_sm.append({**entry, "SM": sm})
-        return (dataset, dataset_each_sm)
+            self.products.append(cast(str, filtered_node["smiles"]))

     def prepare_final_dataset_v2(
         self,
@@ -138,9 +88,7 @@ def prepare_final_dataset_v2(
         for filtered_route in tqdm(self.filtered_data):
             non_permuted_string = str(filtered_route).replace(" ", "")
             non_permuted_paths.add(non_permuted_string)
-            permuted_path_strings = generate_permutations(
-                filtered_route, max_perm=None
-            )
+            permuted_path_strings = generate_permutations(filtered_route, max_perm=None)
             for permuted_path_string in permuted_path_strings:
                 if permuted_path_string in exclude_path_strings:
                     break
@@ -156,11 +104,11 @@ def prepare_final_dataset_v2(

             for path_string in permuted_path_strings:
                 for sm_count, starting_material in enumerate(all_SMs):
-                    products.append(filtered_route["smiles"])
+                    products.append(cast(str, filtered_route["smiles"]))
                     starting_materials.append(starting_material)
                     path_strings.append(path_string)
                     n_steps_list.append(n_steps)
-                    if n_sms is not None and sm_count+1 >= n_sms:
+                    if n_sms is not None and sm_count + 1 >= n_sms:
                         break
         print(f"Created dataset with {len(products)} entries")
         pickle.dump(
@@ -169,6 +117,7 @@ def prepare_final_dataset_v2(
     )
     return non_permuted_paths

+
 # ------- Dataset Processing -------
 # print("--- Processing of the PaRoutes dataset begins!")
 # print("-- starting to process n1 Routes")
@@ -238,12 +187,12 @@ def prepare_final_dataset_v2(

 # ------- Remove SM info from datasets -------

-def remove_sm_from_ds(load_path:Path, save_path:Path):
+
+def remove_sm_from_ds(load_path: Path, save_path: Path) -> None:
     products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb"))
     pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb"))


-
 # remove_sm_from_ds(load_path=save_path / "all_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "all_dataset_nperms=1_nosm.pkl")
 # remove_sm_from_ds(load_path=save_path / "n1_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "n1_dataset_nperms=1_nosm.pkl")
-# remove_sm_from_ds(load_path=save_path / "n5_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "n5_dataset_nperms=1_nosm.pkl")
+# remove_sm_from_ds(load_path=save_path / "n5_dataset_nperms=1_nsms=1.pkl", save_path=save_path / "n5_dataset_nperms=1_nosm.pkl")
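A note on the recurring cast(str, ...) edits above: under mypy --strict, indexing a dict whose values are typed as a union yields the union type, which cannot be appended to a List[str]. A minimal sketch of the pattern, assuming a simplified shape for FilteredDict (the real definition lives in DirectMultiStep.Utils.PreProcess):

from typing import Dict, List, Union, cast

# Assumed, simplified shape standing in for the real FilteredDict;
# the point is only that its values are a union, not plain str.
FilteredDict = Dict[str, Union[str, List["FilteredDict"]]]

node: FilteredDict = {"smiles": "CC(=O)O", "children": []}

products: List[str] = []
# Without the cast, mypy --strict reports that the union type of
# node["smiles"] is not assignable to str.
products.append(cast(str, node["smiles"]))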
