Skip to content

Commit

Permalink
code for figure creation (#3)
Browse files Browse the repository at this point in the history
* FIGURES: port code for score distribution, add type checks

* FIGURES: admet satisfaction throughout AL iterations

* FIGURES: suppress RDKit warnings

* FIGURE: walltime figure

* FIGURES: add ability to change the metric for the indication of the upper boundary of the values for generated smiles (max or percentile)

* FIGURES: chemical similarity analysis for abl inhibitors

* STYLE: extract loader into a separate tool

* FIGURES: similarity analysis of scored distribution

* GIT: stop tracking exports folder

* test

* FIGURES: evolution of distribution analysis

* FIGURES: training curves

* FIGURES: pca analysis

* FIGURES: model memorization

* FIGURES: interactions count analysis

* FIGURES: dataset analysis

* FIGURES: dataset analysis

* FIGURES: all remaining figures code
  • Loading branch information
anmorgunov authored Dec 10, 2023
1 parent 734ebd5 commit a7f055c
Show file tree
Hide file tree
Showing 38 changed files with 7,054 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Python application test with pytest

on:
push:
branches: [ main, sandbox ]
branches: [ main ]

jobs:
build:
Expand Down
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ sandbox.ipynb
/PaperRuns/1_Pretraining/datasets/moses_train.csv.gz
/PaperRuns/1_Pretraining/datasets/moses_test.csv.gz
/PaperRuns/1_Pretraining/model_weights/
/figures/exports/*/html/
# /figures/exports/ligand_distribution/jpg/
# /figures/exports/cluster_distribution/jpg/
/figures/exports/
secret.py
*pkl

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
Empty file added .vscode/launch.json
Empty file.
2 changes: 1 addition & 1 deletion ChemSpaceAL/Configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"logP": {
"func": lambda mol: Descriptors.MolLogP(mol),
"lower": -0.4,
"upper": 6.5,
"upper": 6.5, # TODO: specify which Abl binder has log.p 6.3
}, # AdMET Lab recommends [0,3], [-0.4, 5.6] from Ghose
}
# Dictionary containing scores for different protein-ligand interactions
Expand Down
Empty file added figures/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions figures/admet_satisfaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import modules.secret
import modules.Graph as Graph
import modules.FilterPassing as flt_pass
import pickle
import os
from typing import List

GENERATIONS_PATH = modules.secret.PRODUCTION_RUNS_PATH + "2. Generation/smiles/"
EXPORT_PATH = os.path.join(os.getcwd(), "figures", "exports", "admet_satisfaction", "")

n_iters = 5
ignored = {"fChar"}
configs = [
("model7_hnh_admet", "HNH", "ADMET", "softsub"),
("model7_hnh_admetfg", "HNH", "ADMET+FGs", "softsub"),
("model2_hnh", "HNH", "ADMET+FGs", "admetfg_softsub"),
("model7_1iep_admet", "1IEP", "ADMET", "softsub"),
("model7_1iep_admetfg", "1IEP", "ADMET+FGs", "softsub"),
("model2_1iep", "1IEP", "ADMET+FGs", "admetfg_softsub"),
]
rerun_admet = False
max_val = -float("inf")
for prefix, target, filters, channel in configs:
print(prefix, target, filters, channel)
if rerun_admet:
fnames = flt_pass.prepare_generation_fnames(
prefix=prefix,
n_iters=n_iters,
channel=channel,
filters=filters,
target=target,
)
load_generation = flt_pass.prepare_generation_loader(base_path=GENERATIONS_PATH)
traces_lists: List[flt_pass.Trace] = []
filtered_dicts = []
max_val = -float("inf")
for i, fname in enumerate(fnames):
smiles = load_generation(fname)
filtToData = flt_pass.compute_admet_metrics(smiles)
filtered_dicts.append(filtToData)
pickle.dump(
filtered_dicts,
open(EXPORT_PATH + f"{prefix}_{filters}_{target}_dicts.pkl", "wb"),
)
traces, i_max_val = flt_pass.create_admet_metrics_traces(
filtToData,
showlegend=i == 0,
ignored_metrics=ignored,
distribution_upper_percentile=100,
)
max_val = max(max_val, i_max_val)
traces_lists.append(traces)
else:
filtered_dicts = pickle.load(
open(EXPORT_PATH + f"{prefix}_{filters}_{target}_dicts.pkl", "rb")
)
traces_lists = []
for i, filtToData in enumerate(filtered_dicts):
traces, i_max_val = flt_pass.create_admet_metrics_traces(
filtToData,
showlegend=i == 0,
ignored_metrics=ignored,
distribution_upper_percentile=95,
)
max_val = max(max_val, i_max_val)
traces_lists.append(traces)
max_val = 1.116
fig = flt_pass.create_admet_progression_figure(
traces_lists, v_space=0.1, h_space=0.08, y_max=max_val + 0.05
)
graph = Graph.Graph()
graph.update_parameters(dict(width=1000, height=700, annotation_size=28))
graph.style_figure(fig)
fig.update_layout(
showlegend=True,
legend=dict(
x=1.0,
y=0.5,
font=dict(size=18),
),
)
graph.save_figure(
figure=fig, path=EXPORT_PATH, fname=f"{prefix}_{filters}_{target}"
)
# print(max_val)
33 changes: 33 additions & 0 deletions figures/analyze_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import modules.secret
import os
import modules.ClusterAnalysis as ca

SAMPLING_PATH = modules.secret.PRODUCTION_RUNS_PATH + "3. Sampling/kmeans_objects/"
EXPORT_PATH = os.path.join(
os.getcwd(), "figures", "exports", "cluster_distribution", ""
)

import pickle


configs = [
("model7_1iep_admetfg", "1IEP", "ADMET+FGs", "softsub"),
# ("model2_1iep", "1IEP", "ADMET+FGs", "admetfg_softsub"),
# ("model7_1iep_admet", "1IEP", "ADMET", "softsub"),
]
n_iters = 5
from modules.Graph import Graph

for prefix, target, filters, channel in configs:
loader = ca.prepare_kmeans_fnames(filters="ADMET+FGs")
fnames = loader(prefix, n_iters, channel, filters, target)
fig = ca.plot_cluster_size_evolution(
path=SAMPLING_PATH, fnames=fnames, n_rows=2, n_cols=3
)
gr = Graph()
gr.save_figure(
figure=fig,
path=EXPORT_PATH,
fname=f"test_cluster_size_{prefix}_{channel}_{target}_{filters}",
html=True
)
109 changes: 109 additions & 0 deletions figures/analyze_datasets_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import modules.secret
from modules.Graph import Graph
import modules.DatasetAnalysis as dataset_analysis
import modules.AnalyzeDistribution as analyze_dist
import os
from typing import cast, Iterable

PRETRAINING_PATH = modules.secret.PRODUCTION_RUNS_PATH + "1. Pretraining/datasets/"
GENERATIONS_PATH = modules.secret.PRODUCTION_RUNS_PATH + "2. Generation/smiles/"
PCA_PATH = modules.secret.PRODUCTION_RUNS_PATH + "3. Sampling/pca_weights/"
PCA_FNAME = "scaler_pca_combined_processed_freq1000_block133_120"
EXPORT_PATH = os.path.join(os.getcwd(), "figures", "exports", "datasets_analysis", "")


datasets = ["moses", "guacamol", "combined"]
reduction = "PCA"
desc_type = "mix"
train_sample = 10_000
generation_sample = 10_000

training_smiles = [
dataset_analysis.load_training_smiles(
PRETRAINING_PATH, dataset, sample=train_sample
)
for dataset in datasets
]
generated_smiles = [
dataset_analysis.load_generated_smiles(
GENERATIONS_PATH, dataset, sample=generation_sample
)
for dataset in datasets
]

smile_set = set()
for smile_container in [training_smiles, generated_smiles]:
for smiles in smile_container:
smile_set |= set(smiles)
all_smiles = [training_smiles[i] + generated_smiles[i] for i in range(len(datasets))]
# analyze_dist._calculate_descriptors_for_an_array(smiles=list(smile_set), save_path=EXPORT_PATH, save_name="moses_guac_combined", desc_mode=desc_type)
descriptors = analyze_dist.load_descriptors(
load_path=EXPORT_PATH, load_fname="moses_guac_combined"
)
training_projection = [
analyze_dist.project_smiles(smiles, descriptors) for smiles in training_smiles
]
generated_projection = [
analyze_dist.project_smiles(smiles, descriptors) for smiles in generated_smiles
]
training_reduced, generated_reduced = dataset_analysis.reduce_training_and_generations(
training_projection, generated_projection, PCA_PATH, PCA_FNAME
)

# def invert_even_elts(array):
# return [array[1], array[0], array[3], array[2], array[5], array[4]]


traces = dataset_analysis.prepare_scatter_traces(
training_reduced, generated_reduced,
labels=["Generations", "Training Set"],
colorscale=("#240046", "#80ed99"),
trace_opacity=0.5,
marker_size=2,
marker_width=0.1,
)

fig = dataset_analysis.plot_scatter2d_wsubplots(
traces=traces,
subplot_titles=[
"<b>MOSES</b>",
"<b>GuacaMol</b>",
"<b>Combined Dataset</b>",
],
n_rows=1,
n_cols=3,
)

graph = Graph()
graph.update_parameters(dict(
width=1100,
height=400,
xrange=[-18, 50],
yrange=[-18, 23],
xtick_len=4,
ytick_len=4,
xtick_width=1,
ytick_width=1,
axis_title_size=18,
xaxis_title="Principal Component 1 (18.6% variance explained)",
yaxis_title="Principal Component 2<br>(5.7% variance explained)",
show_xzero=True,
show_yzero=True,
annotation_size=20,
))
graph.style_figure(fig, force_annotations=False)
fig.update_layout(legend=dict(
x=0.01,
y=1.0,
xanchor="left",
yanchor="top",
font=dict(size=16),
orientation="h",
), yaxis2_title="", yaxis3_title="")
graph.save_figure(
figure=fig,
path=EXPORT_PATH,
fname=f"{'+'.join(datasets)}_{desc_type}_{reduction}_trainsample{train_sample}_trainsample{generation_sample}",
html=False,
svg=True,
)
Loading

0 comments on commit a7f055c

Please sign in to comment.