From ad5199966b0e43fbbcfafa61b7bd04378ef239c5 Mon Sep 17 00:00:00 2001 From: Kobi Felton Date: Fri, 9 Jun 2023 10:40:33 +0100 Subject: [PATCH] Run paper table (#135) * Update make file * Add analysis notebook * Enable resuming and skipping training * Update performance plotting * Add scripts * Make final tables * Add dataset version * Upgrade dataset version * Add dataset version to cli and update makefile --- Makefile | 54 +- .../condition_prediction/run.py | 114 ++- .../condition_prediction/utils.py | 39 +- notebooks/clean_wandb.ipynb | 105 +++ notebooks/plot_model_performance_wandb.ipynb | 799 ++++++++++++++++++ notebooks/update_wandb_runs.py | 83 ++ 6 files changed, 1136 insertions(+), 58 deletions(-) create mode 100644 notebooks/clean_wandb.ipynb create mode 100644 notebooks/plot_model_performance_wandb.ipynb create mode 100644 notebooks/update_wandb_runs.py diff --git a/Makefile b/Makefile index c72ca7ab..ae598fec 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ clean_default_num_agent=3 clean_default_num_cat=1 clean_default_num_reag=2 WANDB_ENTITY=ceb-sre +dataset_version=v5 mypy_orderly: @@ -165,7 +166,7 @@ run_python_310: train_model: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=True --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=True --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 @@ -242,40 +243,40 @@ paper_5 : 
paper_plot_uspto_no_trust_filtered_min_frequency_of_occurrence_10_100 # 6. clean (final) paper_gen_uspto_no_trust_no_map: #requires: paper_extract_uspto_no_trust - python -m orderly.clean --output_path="data/orderly/datasets/orderly_no_trust_no_map.parquet" --ord_extraction_path="data/orderly/uspto_no_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_no_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=False --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=3 --num_cat=0 --num_reag=0 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 + python -m orderly.clean --output_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map.parquet" --ord_extraction_path="data/orderly/uspto_no_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_no_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=False --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=3 --num_cat=0 --num_reag=0 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 paper_gen_uspto_no_trust_with_map: #requires: paper_extract_uspto_no_trust - python -m orderly.clean --output_path="data/orderly/datasets/orderly_no_trust_with_map.parquet" --ord_extraction_path="data/orderly/uspto_no_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_no_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=True --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 
--num_solv=2 --num_agent=3 --num_cat=0 --num_reag=0 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 + python -m orderly.clean --output_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map.parquet" --ord_extraction_path="data/orderly/uspto_no_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_no_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=True --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=3 --num_cat=0 --num_reag=0 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 paper_gen_uspto_with_trust_with_map: #requires: paper_extract_uspto_with_trust - python -m orderly.clean --output_path="data/orderly/datasets/orderly_with_trust_with_map.parquet" --ord_extraction_path="data/orderly/uspto_with_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_with_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=True --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=0 --num_cat=1 --num_reag=2 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 + python -m orderly.clean --output_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map.parquet" --ord_extraction_path="data/orderly/uspto_with_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_with_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=True --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 
--num_solv=2 --num_agent=0 --num_cat=1 --num_reag=2 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 paper_gen_uspto_with_trust_no_map: #requires: paper_extract_uspto_with_trust - python -m orderly.clean --output_path="data/orderly/datasets/orderly_with_trust_no_map.parquet" --ord_extraction_path="data/orderly/uspto_with_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_with_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=False --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=0 --num_cat=1 --num_reag=2 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 + python -m orderly.clean --output_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map.parquet" --ord_extraction_path="data/orderly/uspto_with_trust/extracted_ords" --molecules_to_remove_path="data/orderly/uspto_with_trust/all_molecule_names.csv" --min_frequency_of_occurrence=100 --map_rare_molecules_to_other=False --set_unresolved_names_to_none_if_mapped_rxn_str_exists_else_del_rxn=True --remove_rxn_with_unresolved_names=False --set_unresolved_names_to_none=False --num_product=1 --num_reactant=2 --num_solv=2 --num_agent=0 --num_cat=1 --num_reag=2 --consistent_yield=True --scramble=True --train_test_split_fraction=0.9 paper_6: paper_gen_uspto_no_trust_no_map paper_gen_uspto_no_trust_with_map paper_gen_uspto_with_trust_with_map paper_gen_uspto_with_trust_no_map # 7. 
gen fp fp_no_trust_no_map_test: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_no_trust_no_map_test.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_test.parquet" --fp_size=2048 --overwrite=False fp_no_trust_no_map_train: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_no_trust_no_map_train.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_train.parquet" --fp_size=2048 --overwrite=False fp_no_trust_with_map_test: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_no_trust_with_map_test.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_test.parquet" --fp_size=2048 --overwrite=False fp_no_trust_with_map_train: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_no_trust_with_map_train.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_train.parquet" --fp_size=2048 --overwrite=False fp_with_trust_with_map_test: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_with_trust_with_map_test.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_test.parquet" --fp_size=2048 --overwrite=False fp_with_trust_with_map_train: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_with_trust_with_map_train.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp 
--clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_train.parquet" --fp_size=2048 --overwrite=False fp_with_trust_no_map_test: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_with_trust_no_map_test.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_test.parquet" --fp_size=2048 --overwrite=False fp_with_trust_no_map_train: - python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets/orderly_with_trust_no_map_train.parquet" --fp_size=2048 --overwrite=False + python -m orderly.gen_fp --clean_data_folder_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_train.parquet" --fp_size=2048 --overwrite=False paper_7: fp_no_trust_no_map_test fp_no_trust_no_map_train fp_no_trust_with_map_test fp_no_trust_with_map_train fp_with_trust_with_map_test fp_with_trust_with_map_train fp_with_trust_no_map_test fp_with_trust_no_map_train @@ -288,36 +289,39 @@ paper_gen_all: paper_1 paper_2 paper_3 paper_4 paper_5 paper_6 paper_7 #Remember to switch env here (must contain TF, e.g. 
tf_mac_m1) # Full dataset no_trust_no_map_train: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_no_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_no_trust_no_map_test.parquet" --output_folder_path="models/no_trust_no_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_test.parquet" --output_folder_path="models/no_trust_no_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) no_trust_with_map_train: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) with_trust_no_map_train: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_with_trust_no_map_train.parquet" 
--test_data_path="data/orderly/datasets/orderly_with_trust_no_map_test.parquet" --output_folder_path="models/with_trust_no_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_test.parquet" --output_folder_path="models/with_trust_no_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) with_trust_with_map_train: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_with_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_with_trust_with_map_test.parquet" --output_folder_path="models/with_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_test.parquet" --output_folder_path="models/with_trust_with_map" --train_fraction=1 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) # 20% of data no_trust_no_map_train_20: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_no_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_no_trust_no_map_test.parquet" --output_folder_path="models/no_trust_no_map_20" --train_fraction=0.2 --train_val_split=0.8 
--overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_no_map_test.parquet" --output_folder_path="models/no_trust_no_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) no_trust_with_map_train_20: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_no_trust_with_map_test.parquet" --output_folder_path="models/no_trust_with_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) with_trust_no_map_train_20: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_with_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_with_trust_no_map_test.parquet" --output_folder_path="models/with_trust_no_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction 
--train_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_no_map_test.parquet" --output_folder_path="models/with_trust_no_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) with_trust_with_map_train_20: - python -m condition_prediction --train_data_path="data/orderly/datasets/orderly_with_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets/orderly_with_trust_with_map_test.parquet" --output_folder_path="models/with_trust_with_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=WANDB_ENTITY + python -m condition_prediction --train_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_train.parquet" --test_data_path="data/orderly/datasets_$(dataset_version)/orderly_with_trust_with_map_test.parquet" --output_folder_path="models/with_trust_with_map_20" --train_fraction=0.2 --train_val_split=0.8 --overwrite=False --epochs=20 --evaluate_on_test_data=True --early_stopping_patience=5 --wandb_entity=$(WANDB_ENTITY) --dataset_version=$(dataset_version) # Sweeps RANDOM_SEEDS = 12345 54321 98765 -TRAIN_FRACS = 0.2 0.4 0.6 0.8 1.0 -DATASETS_PATH = /project/studios/orderly-preprocessing/ORDerly/data/orderly/datasets/ -DATASETS = no_trust_no_map no_trust_with_map with_trust_no_map with_trust_with_map +TRAIN_FRACS = 1.0 0.2 0.4 0.6 0.8 +# Path on lightning +# DATASETS_PATH = /project/studios/orderly-preprocessing/ORDerly/data/orderly/datasets_$(dataset_version)/ +# Normal path +DATASETS_PATH = ORDerly/data/orderly/datasets_$(dataset_version)/ +DATASETS = no_trust_with_map no_trust_no_map with_trust_with_map with_trust_no_map dataset_size_sweep: @for random_seed in ${RANDOM_SEEDS}; \ do \ @@ -325,7 +329,7 @@ 
dataset_size_sweep: do \ for train_frac in ${TRAIN_FRACS}; \ do \ - rm -rf .tf_cache* && python -m condition_prediction --train_data_path=${DATASETS_PATH}/orderly_$${dataset}_train.parquet --test_data_path=${DATASETS_PATH}/orderly_$${dataset}_test.parquet --output_folder_path=models/$${dataset} --train_fraction=$${train_frac} --train_val_split=0.8 --random_seed=$${random_seed} --overwrite=True --batch_size=512 --epochs=100 --early_stopping_patience=0 --evaluate_on_test_data=True --wandb_entity=$(WANDB_ENTITY); \ + rm -rf .tf_cache* && python -m condition_prediction --train_data_path=${DATASETS_PATH}/orderly_$${dataset}_train.parquet --test_data_path=${DATASETS_PATH}/orderly_$${dataset}_test.parquet --output_folder_path=models/$${dataset} --dataset_version=$(dataset_version) --train_fraction=$${train_frac} --train_val_split=0.8 --random_seed=$${random_seed} --overwrite=True --batch_size=512 --epochs=100 --train_mode=0 --early_stopping_patience=0 --evaluate_on_test_data=True --wandb_entity=$(WANDB_ENTITY) ; \ done \ done \ done diff --git a/condition_prediction/condition_prediction/run.py b/condition_prediction/condition_prediction/run.py index 82b04532..5cd98b1f 100644 --- a/condition_prediction/condition_prediction/run.py +++ b/condition_prediction/condition_prediction/run.py @@ -11,16 +11,15 @@ LOG = logging.getLogger(__name__) -from functools import partial import numpy as np import pandas as pd import tensorflow as tf +import wandb from click_loglevel import LogLevel from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint -import wandb from condition_prediction.constants import HARD_SELECTION, SOFT_SELECTION, TEACHER_FORCE from condition_prediction.data_generator import ( get_datasets, @@ -34,6 +33,7 @@ ) from condition_prediction.utils import ( TrainingMetrics, + download_model_from_wandb, frequency_informed_accuracy, get_grouped_scores, get_grouped_scores_top_n, @@ -93,8 
+93,13 @@ class ConditionPrediction: wandb_entity: Optional[str] = None wandb_tags: Optional[List[str]] = None wandb_group: Optional[str] = None + wandb_run_id: Optional[str] = None + resume: bool = False + resume_from_best: bool = False verbosity: int = 2 random_seed: int = 12345 + skip_training: bool = False + dataset_version: str = "v5" def __post_init__(self) -> None: pass @@ -131,6 +136,7 @@ def run_model_arguments(self) -> None: random_seed=self.random_seed, epochs=self.epochs, train_mode=self.train_mode, + skip_training=self.skip_training, fp_size=self.fp_size, dropout=self.dropout, hidden_size_1=self.hidden_size_1, @@ -154,7 +160,11 @@ def run_model_arguments(self) -> None: wandb_logging=self.wandb_logging, wandb_tags=self.wandb_tags, wandb_group=self.wandb_group, + wandb_run_id=self.wandb_run_id, + resume=self.resume, + resume_from_best=self.resume_from_best, verbosity=self.verbosity, + dataset_version=self.dataset_version, ) @staticmethod @@ -269,38 +279,38 @@ def evaluate_model(model, dataset, encoders): solvent_scores_top1 = get_grouped_scores( ground_truth[:2], predictions[:2], encoders[:2] ) - metrics["test_solvent_accuracy_top1"] = np.mean(solvent_scores_top1) + metrics["solvent_accuracy_top1"] = np.mean(solvent_scores_top1) # 3 agents scores agent_scores_top1 = get_grouped_scores( ground_truth[2:], predictions[2:], encoders[2:] ) - metrics["test_three_agents_accuracy_top1"] = np.mean(agent_scores_top1) + metrics["three_agents_accuracy_top1"] = np.mean(agent_scores_top1) # Overall scores overall_scores_top1 = np.stack( [solvent_scores_top1, agent_scores_top1], axis=1 ).all(axis=1) - metrics["test_overall_accuracy_top1"] = np.mean(overall_scores_top1) + metrics["overall_accuracy_top1"] = np.mean(overall_scores_top1) # Top 3 accuracies # Solvent score solvent_scores_top3 = get_grouped_scores_top_n( ground_truth[:2], predictions[:2], encoders[:2], 3 ) - metrics["test_solvent_accuracy_top3"] = np.mean(solvent_scores_top3) + 
metrics["solvent_accuracy_top3"] = np.mean(solvent_scores_top3) # 3 agents scores agent_scores_top3 = get_grouped_scores_top_n( ground_truth[2:], predictions[2:], encoders[2:], 3 ) - metrics["test_three_agents_accuracy_top3"] = np.mean(agent_scores_top3) + metrics["three_agents_accuracy_top3"] = np.mean(agent_scores_top3) # Overall scores overall_scores_top3 = np.stack( [solvent_scores_top3, agent_scores_top3], axis=1 ).all(axis=1) - metrics["test_overall_accuracy_top3"] = np.mean(overall_scores_top3) + metrics["overall_accuracy_top3"] = np.mean(overall_scores_top3) return metrics @@ -340,7 +350,11 @@ def run_model( wandb_tags: Optional[List[str]] = None, wandb_group: Optional[str] = None, verbosity: int = 2, - dataset_version: str = "v4", + dataset_version: str = "v5", + skip_training: bool = False, + wandb_run_id: Optional[str] = None, + resume: bool = False, + resume_from_best: bool = False, ) -> None: """ Run condition prediction training @@ -563,6 +577,8 @@ def run_model( tags=wandb_tags, group=wandb_group, config=config, + id=wandb_run_id if resume else None, + resume="allow" if resume else None, # sync_tensorboard=True, ) callbacks.extend( @@ -580,31 +596,58 @@ def run_model( save_weights_only=True, ) ) - - use_multiprocessing = True if workers > 0 else False - h = None last_checkpoint_filepath = output_folder_path / "weights.last.hdf5" - try: - h = model.fit( - train_dataset, - epochs=epochs, - verbose=verbosity, - validation_data=val_dataset_for_train, - callbacks=callbacks, - use_multiprocessing=use_multiprocessing, - workers=workers, + if resume and wandb_run_id is not None: + # Download the model weights from wandb + api = wandb.Api() + run = api.run(f"{wandb_entity}/{wandb_project}/{wandb_run_id}") + + # Download best model + download_model_from_wandb( + run, + alias="best", + root=output_folder_path, ) - except KeyboardInterrupt: - LOG.info( - "Keyboard interrupt detected. Stopping training and doing evaluation." 
+ download_model_from_wandb( + run, + alias="last_epoch", + root=output_folder_path, ) - finally: - model.save_weights(last_checkpoint_filepath) + + # Update weights + if resume_from_best: + model.load_weights(best_checkpoint_filepath) + else: + model.load_weights(last_checkpoint_filepath) update_teacher_forcing_model_weights( update_model=pred_model, to_copy_model=model ) - # Upload the best and last model model - if wandb_logging: + + use_multiprocessing = True if workers > 0 else False + h = None + if not skip_training: + try: + h = model.fit( + train_dataset, + epochs=epochs, + verbose=verbosity, + validation_data=val_dataset_for_train, + callbacks=callbacks, + use_multiprocessing=use_multiprocessing, + workers=workers, + ) + except KeyboardInterrupt: + LOG.info( + "Keyboard interrupt detected. Stopping training and doing evaluation." + ) + finally: + model.save_weights(last_checkpoint_filepath) + update_teacher_forcing_model_weights( + update_model=pred_model, to_copy_model=model + ) + + # Upload the best and last model + if wandb_logging and not skip_training: # Save and upload last model artifact = wandb.Artifact( # type: ignore name=f"run_{wandb_run.id}_model", # type: ignore @@ -623,6 +666,10 @@ def run_model( # Train and val metrics train_val_metrics_dict = {} + model.load_weights(last_checkpoint_filepath) + update_teacher_forcing_model_weights( + update_model=pred_model, to_copy_model=model + ) train_val_metrics_dict["trust_labelling"] = trust_labelling train_val_metrics_dict.update( { @@ -697,7 +744,6 @@ def run_model( test_metrics_file_path = output_folder_path / "test_metrics.json" with open(test_metrics_file_path, "w") as file: json.dump(jsonify_dict(test_metrics_dict), file) - if wandb_logging: # Log data artifact = wandb.Artifact( # type: ignore @@ -757,6 +803,12 @@ def run_model( type=int, help="Training mode. 
0 for Teacher force, 1 for hard seleciton, 2 for soft selection", ) +@click.option( + "--dataset_version", + default="v5", + type=str, + help="The version of the dataset", +) @click.option( "--epochs", default=20, @@ -932,6 +984,7 @@ def main_click( output_folder_path: pathlib.Path, train_fraction: float, train_val_split: float, + dataset_version: str, random_seed: int, epochs: int, train_mode: int, @@ -989,6 +1042,7 @@ def main_click( train_fraction=train_fraction, train_val_split=train_val_split, random_seed=random_seed, + dataset_version=dataset_version, epochs=epochs, train_mode=train_mode, early_stopping_patience=early_stopping_patience, @@ -1028,6 +1082,7 @@ def main( output_folder_path: pathlib.Path, train_fraction: float, train_val_split: float, + dataset_version: str, epochs: int, random_seed: int, train_mode: int, @@ -1134,6 +1189,7 @@ def main( output_folder_path=output_folder_path, train_fraction=train_fraction, train_val_split=train_val_split, + dataset_version=dataset_version, random_seed=random_seed, generate_fingerprints=generate_fingerprints, fp_size=fp_size, diff --git a/condition_prediction/condition_prediction/utils.py b/condition_prediction/condition_prediction/utils.py index 4003f902..7220189f 100644 --- a/condition_prediction/condition_prediction/utils.py +++ b/condition_prediction/condition_prediction/utils.py @@ -1,18 +1,19 @@ +import math import os import socket from collections import Counter from copy import deepcopy from datetime import datetime from datetime import datetime as dt -from typing import List +from itertools import product +from pathlib import Path +from typing import List, Tuple import matplotlib.pyplot as plt import numpy as np from keras import callbacks from sklearn.preprocessing import OneHotEncoder - -from itertools import product -import math +from wandb.sdk.wandb_run import Run def log_dir(prefix="", comment=""): @@ -392,3 +393,33 @@ def listtonumpy(a, copy=True): if transform_all: a = np.array(a) return a + + +def 
download_model_from_wandb( + run: Run, alias="best", weights_file="weights.best.hdf5", root=None +) -> Tuple[Path, str]: + """Download best model from wandb + + Arguments + ---------- + run: wandb.Run + + + Returns + ------- + Path to best model checkpoint + + """ + # Get best checkpoint + artifacts = run.logged_artifacts() # type: ignore + ckpt_path = None + artifact_name = None + for artifact in artifacts: + if artifact.type == "model" and alias in artifact.aliases: + ckpt_path = artifact.download(root=root) + artifact_name = artifact.name + if ckpt_path is None or artifact_name is None: + raise ValueError("No checkpoint found with alias: {}".format(alias)) + ckpt_path = Path(ckpt_path) / weights_file + + return ckpt_path, artifact_name diff --git a/notebooks/clean_wandb.ipynb b/notebooks/clean_wandb.ipynb new file mode 100644 index 00000000..a29762a3 --- /dev/null +++ b/notebooks/clean_wandb.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "from tqdm import tqdm\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "api = wandb.Api()\n", + "wandb_entity=\"ceb-sre\"\n", + "wandb_project=\"orderly\"\n", + "runs = api.runs(\n", + " f\"{wandb_entity}/{wandb_project}\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 7/7 [00:21<00:00, 3.06s/it]\n" + ] + } + ], + "source": [ + "for run in tqdm(runs):\n", + " artifacts = run.logged_artifacts()\n", + " for artifact in artifacts:\n", + " if artifact.type == \"model\":\n", + " artifact.delete(delete_aliases=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 3476/3476 [00:52<00:00, 65.98it/s]\n" + ] + } + ], + "source": [ + "start_date = datetime(2023, 4, 1)\n", + "date_format = \"%Y-%m-%dT%H:%M:%S\"\n", + "for run in tqdm(runs):\n", + " date_str = runs[0].createdAt\n", + " date_obj = datetime.strptime(date_str, date_format)\n", + " artifacts = run.logged_artifacts()\n", + " if date_obj < start_date:\n", + " for artifact in artifacts:\n", + " if artifact.type == \"model\":\n", + " artifact.delete(delete_aliases=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/plot_model_performance_wandb.ipynb b/notebooks/plot_model_performance_wandb.ipynb new file mode 100644 index 00000000..f5d04b57 --- /dev/null +++ b/notebooks/plot_model_performance_wandb.ipynb @@ -0,0 +1,799 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-06-07 14:34:45.262124: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2023-06-07 14:34:45.985455: W 
tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "2023-06-07 14:34:49.711563: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", + "2023-06-07 14:34:49.734714: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", + "2023-06-07 14:34:49.735672: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" + ] + } + ], + "source": [ + "from condition_prediction.run import ConditionPrediction\n", + "import wandb\n", + "from tqdm import tqdm\n", + "from datetime import datetime\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pathlib\n", + "from tqdm import tqdm\n", + "import gc" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "api = wandb.Api()\n", + "wandb_entity=\"ceb-sre\"\n", + "wandb_project=\"orderly\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prefetch buffer size: -1\n", + "Prefetch buffer size: -1\n", + "Prefetch buffer size: -1\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/zeus/content/wandb/run-20230607_021600-zbx7fqj5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Resuming run fast-sound-304 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/ceb-sre/orderly" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/ceb-sre/orderly/runs/zbx7fqj5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact run_zbx7fqj5_model:v1, 67.03MB. 1 files... \n", + "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", + "Done. 0:0:1.0\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact run_zbx7fqj5_model:v0, 67.03MB. 1 files... \n", + "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", + "Done. 0:0:1.2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "139/139 [==============================] - 2s 10ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-06-07 02:16:06.394389: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "139/139 [==============================] - 1s 10ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-06-07 02:17:33.114064: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "73/73 [==============================] - 2s 16ms/step - loss: 14.3971 - mol1_loss: 2.5769 - mol2_loss: 2.3048 - mol3_loss: 3.0256 - mol4_loss: 4.0312 - mol5_loss: 2.4585 - mol1_acc: 0.2945 - mol1_top3: 0.5576 - mol1_top5: 0.6933 - mol2_acc: 0.5355 - mol2_top3: 0.6867 - mol2_top5: 0.7460 - 
mol3_acc: 0.2804 - mol3_top3: 0.5287 - mol3_top5: 0.6506 - mol4_acc: 0.3250 - mol4_top3: 0.4592 - mol4_top5: 0.5393 - mol5_acc: 0.7476 - mol5_top3: 0.8141 - mol5_top5: 0.8418\n", + "73/73 [==============================] - 1s 10ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-06-07 02:19:00.864700: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "73/73 [==============================] - 1s 11ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-06-07 02:19:40.395855: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization\n" + ] + }, + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... (success)." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run summary:


epoch/epoch99
epoch/learning_rate0.01
epoch/loss1.09914
epoch/mol1_acc0.91118
epoch/mol1_loss0.2798
epoch/mol1_top30.98683
epoch/mol1_top50.99557
epoch/mol2_acc0.93469
epoch/mol2_loss0.20221
epoch/mol2_top30.99249
epoch/mol2_top50.99727
epoch/mol3_acc0.9036
epoch/mol3_loss0.31104
epoch/mol3_top30.985
epoch/mol3_top50.99469
epoch/mol4_acc0.93643
epoch/mol4_loss0.20053
epoch/mol4_top30.99228
epoch/mol4_top50.9974
epoch/mol5_acc0.96917
epoch/mol5_loss0.10558
epoch/mol5_top30.99613
epoch/mol5_top50.9987
epoch/time_per_step0.04884
epoch/training_throughput10482.20503
epoch/val_loss22.84926
epoch/val_mol1_acc0.26352
epoch/val_mol1_loss6.14518
epoch/val_mol1_top30.48875
epoch/val_mol1_top50.61298
epoch/val_mol2_acc0.46446
epoch/val_mol2_loss4.19288
epoch/val_mol2_top30.70539
epoch/val_mol2_top50.79343
epoch/val_mol3_acc0.27778
epoch/val_mol3_loss6.79353
epoch/val_mol3_top30.4941
epoch/val_mol3_top50.5984
epoch/val_mol4_acc0.54362
epoch/val_mol4_loss4.00107
epoch/val_mol4_top30.74025
epoch/val_mol4_top50.80768
epoch/val_mol5_acc0.81986
epoch/val_mol5_loss1.71659
epoch/val_mol5_top30.92357
epoch/val_mol5_top50.94678
frequency_informed_agent_accuracy0.08395
frequency_informed_agent_accuracy_top_10.08395
frequency_informed_agent_accuracy_top_30.18618
frequency_informed_overall_accuracy0.01638
frequency_informed_overall_accuracy_top_10.01638
frequency_informed_overall_accuracy_top_30.03929
frequency_informed_solvent_accuracy0.08808
frequency_informed_solvent_accuracy_top_10.08808
frequency_informed_solvent_accuracy_top_30.22639
loss14.38841
mol1_acc0.29105
mol1_loss2.57782
mol1_top30.55775
mol1_top50.69565
mol2_acc0.53309
mol2_loss2.30201
mol2_top30.68826
mol2_top50.74633
mol3_acc0.28063
mol3_loss3.02522
mol3_top30.52925
mol3_top50.65029
mol4_acc0.3285
mol4_loss4.02904
mol4_top30.45636
mol4_top50.53967
mol5_acc0.74852
mol5_loss2.45431
mol5_top30.81477
mol5_top50.84318
trust_labellingFalse

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run fast-sound-304 at: https://wandb.ai/ceb-sre/orderly/runs/zbx7fqj5
Synced 3 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20230607_021600-zbx7fqj5/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "BASE_PATH = pathlib.Path(\"/project/studios/orderly-preprocessing/ORDerly/\")\n", + "DATASETS_PATH = BASE_PATH / \"data/orderly/datasets/\"\n", + "MODEL_PATH = pathlib.Path(\"ORDerly/models/\")\n", + "dataset = \"no_trust_with_map\"\n", + "filters = {\n", + " # \"state\": \"crashed\",\n", + " \"config.output_folder_path\": {\n", + " \"$in\": [\n", + " f\"models/{dataset}\",\n", + " str(MODEL_PATH / dataset),\n", + " f\"/Users/Kobi/Documents/Research/phd_code/ORDerly/models/{dataset}\",\n", + " ],\n", + " },\n", + " \"config.random_seed\": 12345,\n", + " \"config.train_fraction\": 0.2,\n", + " \"config.dataset_version\": \"v4\",\n", + " \"config.train_mode\": 0, # Teacher forcing\n", + "}\n", + "# filters = {\n", + "# \"id\": {\"$in\": [\"zl4inibc\", \"mdwxixa4\", \"zbx7fqj5\"]}\n", + "# }\n", + "runs = api.runs(f\"{wandb_entity}/{wandb_project}\", filters=filters)\n", + "# if not len(runs) == 5: # For 5 training fractions\n", + "# raise ValueError(f\"Not 5 runs for {dataset} (found {len(runs)}, seed {random_seed})\")\n", + "\n", + "for run in runs:\n", + " config = dict(run.config)\n", + " train_data_path = pathlib.Path(\n", + " f\"{DATASETS_PATH}/orderly_{dataset}_train.parquet\"\n", + " )\n", + " test_data_path = pathlib.Path(\n", + " f\"{DATASETS_PATH}/orderly_{dataset}_test.parquet\"\n", + " )\n", + " fp_directory = train_data_path.parent / \"fingerprints\"\n", + " train_fp_path = fp_directory / (train_data_path.stem + \".npy\")\n", + " test_fp_path = fp_directory / (test_data_path.stem + \".npy\")\n", + " output_folder_path = MODEL_PATH / dataset\n", + " 
output_folder_path.mkdir(parents=True, exist_ok=True)\n", + " tags = dataset.split(\"_\")\n", + " tags = [f\"{tags[0]}_{tags[1]}\", f\"{tags[2]}_{tags[3]}\"]\n", + " config.update(\n", + " {\n", + " \"train_data_path\": train_data_path,\n", + " \"test_data_path\": test_data_path,\n", + " \"train_fp_path\": train_fp_path,\n", + " \"test_fp_path\": test_fp_path,\n", + " \"output_folder_path\": output_folder_path,\n", + " \"skip_training\": True,\n", + " \"resume\": True,\n", + " \"resume_from_best\": True,\n", + " \"generate_fingerprints\": False,\n", + " \"wandb_run_id\": run.id,\n", + " \"wandb_tags\": tags,\n", + " }\n", + " )\n", + " del config[\"n_val\"]\n", + " del config[\"n_test\"]\n", + " del config[\"n_train\"]\n", + " del config[\"dataset_version\"]\n", + " instance = ConditionPrediction(**config)\n", + " instance.run_model_arguments()\n", + " wandb.finish()\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(runs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate Table for dataset v4" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Solvents & 47 // 58 // \\textcolor{lessgreen}{21\\%} & 50 // 61 // \\textcolor{lessgreen}{22\\%} & 23 // 42 // \\textcolor{lessgreen}{26\\%} & 24 // 45 // \\textcolor{lessgreen}{28\\%} \\\\ \n", + "Agents & 54 // 70 // \\textcolor{lessgreen}{35\\%} & 58 // 72 // \\textcolor{lessgreen}{32\\%} & 19 // 39 // \\textcolor{lessgreen}{25\\%} & 21 // 42 // \\textcolor{lessgreen}{27\\%} \\\\ \n", + "Solvents \\& Agents & 31 // 44 // \\textcolor{lessgreen}{19\\%} & 33 // 47 // \\textcolor{lessgreen}{21\\%} & 4 // 21 // \\textcolor{lessgreen}{18\\%} & 5 // 24 // 
\\textcolor{lessgreen}{21\\%} \\\\\n" + ] + } + ], + "source": [ + "DATASETS = [\"with_trust_with_map\",\"with_trust_no_map\",\"no_trust_with_map\", \"no_trust_no_map\"]\n", + "lines = [\"Solvents\", \"Agents\", \"Solvents \\& Agents\"]\n", + "top_n = 3\n", + "for dataset in DATASETS:\n", + " filters = {\n", + " \"state\": \"finished\",\n", + " \"config.output_folder_path\": f\"models/{dataset}\",\n", + " \"config.random_seed\": 54321,\n", + " \"config.train_fraction\": 1.0,\n", + " \"config.train_mode\": 0, # Teacher forcing\n", + " \"config.dataset_version\": {\"$in\": [\"v4\", \"v5\"]}\n", + " }\n", + " runs = api.runs(\n", + " f\"{wandb_entity}/{wandb_project}\",\n", + " filters=filters\n", + " )\n", + " run = runs[0]\n", + " if len(runs)>0:\n", + " for r in runs:\n", + " if r.config[\"dataset_version\"] == \"v5\":\n", + " run = r\n", + " break\n", + "\n", + " # Get model solvent, agent and overall accuracy\n", + " test_best = run.summary[\"test_best\"]\n", + " solvent_accuracy = test_best[f\"solvent_accuracy_top{top_n}\"]\n", + " agent_accuracy = test_best[f\"three_agents_accuracy_top{top_n}\"]\n", + " overall_accuracy = test_best[f\"overall_accuracy_top{top_n}\"]\n", + "\n", + " # Get frequency informed solvent, agent and overall accuracy\n", + " fi_solvent_accuracy = run.summary[f\"frequency_informed_solvent_accuracy_top_{top_n}\"]\n", + " fi_agent_accuracy = run.summary[f\"frequency_informed_agent_accuracy_top_{top_n}\"]\n", + " fi_overall_accuracy = run.summary[f\"frequency_informed_overall_accuracy_top_{top_n}\"]\n", + "\n", + " # Improvement\n", + " solvent_improvement = (solvent_accuracy-fi_solvent_accuracy)/(1-fi_solvent_accuracy)\n", + " solvent_improvement_color = \"lessgreen\" if solvent_improvement>0 else \"red\"\n", + " agent_improvement = (agent_accuracy-fi_agent_accuracy)/(1-fi_agent_accuracy)\n", + " agent_improvement_color = \"lessgreen\" if agent_improvement>0 else \"red\"\n", + " overall_improvement =
(overall_accuracy-fi_overall_accuracy)/(1-fi_overall_accuracy)\n", + " overall_improvement_color = \"lessgreen\" if overall_improvement>0 else \"red\"\n", + "\n", + " # Create table lines\n", + " lines[0] += f\" & {fi_solvent_accuracy*100:.0f} // {solvent_accuracy*100:.0f} // \\\\textcolor{{{solvent_improvement_color}}}{{{solvent_improvement*100:.0f}\\%}} \"\n", + " lines[1] += f\" & {fi_agent_accuracy*100:.0f} // {agent_accuracy*100:.0f} // \\\\textcolor{{{agent_improvement_color}}}{{{agent_improvement*100:.0f}\\%}} \"\n", + " lines[2] += f\" & {fi_overall_accuracy*100:.0f} // {overall_accuracy*100:.0f} // \\\\textcolor{{{overall_improvement_color}}}{{{overall_improvement*100:.0f}\\%}} \"\n", + "print(\"\\\\\\\\ \\n\".join(lines) + \"\\\\\\\\\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "DATASETS = [\"no_trust_no_map\", \"no_trust_with_map\"]\n", + "LABELS = {\n", + " \"with_trust_with_map\": r\"Labelling, rare $\\rightarrow$ other\",\n", + " \"with_trust_no_map\": r\"Labelling, rare $\\rightarrow$ delete rxn\",\n", + " \"no_trust_with_map\": r\"Reaction string, rare $\\rightarrow$ other\",\n", + " \"no_trust_no_map\": r\"Reaction string, rare $\\rightarrow$ delete rxn\",\n", + "}\n", + "TRAIN_FRACS = [0.2, 0.4, 0.6, 0.8, 1.0]\n", + "fig, ax = plt.subplots(1)\n", + "markers = [\"o\", \"d\", \"s\", \"^\"]\n", + "top_n = 3\n", + "colors = {\n", + " \"no_trust_with_map\": \"#5C4682\",\n", + " \"no_trust_no_map\": \"#5e813f\",\n", + "}\n", + "for i, dataset in enumerate(DATASETS):\n", + " overall_accuracies = []\n", + " n_train = []\n", + " for train_fraction in TRAIN_FRACS:\n", + " filters = {\n", + " \"state\": \"finished\",\n", + " \"config.output_folder_path\": f\"models/{dataset}\",\n", + " # \"config.random_seed\": {\"$in\": [12345, 54321, 98765]},\n", + " \"config.random_seed\": 54321,\n", + " \"config.train_fraction\": train_fraction,\n", + " \"config.train_mode\": 0, # Teacher forcing\n", + " \"config.dataset_version\": {\"$in\": [\"v4\", \"v5\"]}\n", + " }\n", + " runs = api.runs(\n", + " f\"{wandb_entity}/{wandb_project}\",\n", + " filters=filters\n", + " )\n", + " run = runs[0]\n", + " if len(runs)>0:\n", + " for r in runs:\n", + " if run.config[\"dataset_version\"] == \"v5\":\n", + " run = r\n", + " break\n", + "\n", + " # Get overall accuracy\n", + " acc_local = []\n", + " for run in runs:\n", + " overall_accuracy = run.summary[f\"test_best\"][f\"overall_accuracy_top{top_n}\"]\n", + " fi_overall_accuracy = run.summary[f\"frequency_informed_overall_accuracy_top_{top_n}\"]\n", + " overall_improvement = (overall_accuracy-fi_overall_accuracy)/(1-fi_overall_accuracy)\n", + " acc_local.append(overall_improvement)\n", + " 
overall_accuracies.append(np.mean(overall_improvement)*100)\n", + " n_train.append(run.config[\"n_train\"])\n", + " \n", + " # Add line to plot\n", + " label = LABELS[dataset]\n", + " ax.plot(\n", + " # TRAIN_FRACS,\n", + " np.array(n_train)/ 1e3, \n", + " overall_accuracies, \n", + " label=label, \n", + " linewidth=3.5, \n", + " marker=markers[i], \n", + " markersize=10,\n", + " color=colors[dataset]\n", + " )\n", + "\n", + "# Formatting\n", + "axis_fontsize = 16\n", + "heading_fontsize = 18\n", + "ax.legend(loc=\"lower right\", fontsize=axis_fontsize)\n", + "ax.set_xlabel(\"Train set size (thousands)\", fontsize=heading_fontsize)\n", + "ax.set_ylabel(\"Overall AIB (%)\", fontsize=heading_fontsize)\n", + "# ax.set_xticks(TRAIN_FRACS)\n", + "# ax.set_xticklabels(ax.gefontsize=axis_fontsize)\n", + "ax.tick_params(labelsize=axis_fontsize)\n", + "# ax.ticklabel_format( style='sci',scilimits=(4,4))\n", + "ylabels = np.arange(8, 22, 2)\n", + "ax.set_yticks(ylabels)\n", + "ax.set_yticklabels([f\"{ylabel:.0f}\" for ylabel in ylabels], fontsize=axis_fontsize)\n", + "fig.tight_layout()\n", + "fig.savefig(\"scaling_behavior.png\", dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "len(runs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "train_fraction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'with_trust_no_map'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + 
"cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Solvents & 34 // 52 // \\textcolor{lessgreen}{28\\%} & 34 // 53 // \\textcolor{lessgreen}{30\\%} & 24 // 45 // \\textcolor{lessgreen}{28\\%} & 23 // 42 // \\textcolor{lessgreen}{26\\%} \\\\ \n", + "Agents & 43 // 65 // \\textcolor{lessgreen}{39\\%} & 43 // 67 // \\textcolor{lessgreen}{43\\%} & 21 // 42 // \\textcolor{lessgreen}{27\\%} & 19 // 39 // \\textcolor{lessgreen}{25\\%} \\\\ \n", + "Solvents \\& Agents & 18 // 39 // \\textcolor{lessgreen}{26\\%} & 19 // 41 // \\textcolor{lessgreen}{27\\%} & 5 // 24 // \\textcolor{lessgreen}{21\\%} & 4 // 21 // \\textcolor{lessgreen}{18\\%} \\\\\n" + ] + } + ], + "source": [ + "DATASETS = [\"with_trust_with_map\",\"with_trust_no_map\", \"no_trust_no_map\", \"no_trust_with_map\"]\n", + "lines = [\"Solvents\", \"Agents\", \"Solvents \\& Agents\"]\n", + "top_n = 3\n", + "for dataset in DATASETS:\n", + " filters = {\n", + " \"state\": \"finished\",\n", + " \"config.output_folder_path\": f\"models/{dataset}\",\n", + " \"config.random_seed\": 54321,\n", + " \"config.train_fraction\": 1.0,\n", + " \"config.train_mode\": 0, # Teacher forcing\n", + " }\n", + " runs = api.runs(\n", + " f\"{wandb_entity}/{wandb_project}\",\n", + " filters=filters\n", + " )\n", + " assert len(runs) == 1\n", + " run = runs[0]\n", + "\n", + " # Get model solvent, agent and overall accuracy\n", + " test_best = run.summary[\"test_best\"]\n", + " solvent_accuracy = test_best[f\"solvent_accuracy_top{top_n}\"]\n", + " agent_accuracy = test_best[f\"three_agents_accuracy_top{top_n}\"]\n", + " overall_accuracy = test_best[f\"overall_accuracy_top{top_n}\"]\n", + "\n", + " # Get frequency informed solvent, agent and overall accuracy\n", + " fi_solvent_accuracy = run.summary[f\"frequency_informed_solvent_accuracy_top_{top_n}\"]\n", + " fi_agent_accuracy = run.summary[f\"frequency_informed_agent_accuracy_top_{top_n}\"]\n", + 
" fi_overall_accuracy = run.summary[f\"frequency_informed_overall_accuracy_top_{top_n}\"]\n", + "\n", + " # Improvement\n", + " solvent_improvement = (solvent_accuracy-fi_solvent_accuracy)/(1-fi_solvent_accuracy)\n", + " solvent_improvement_color = \"lessgreen\" if solvent_improvement>0 else \"red\"\n", + " agent_improvement = (agent_accuracy-fi_agent_accuracy)/(1-fi_agent_accuracy)\n", + " agent_improvement_color = \"lessgreen\" if agent_improvement>0 else \"red\"\n", + " overall_improvement = (overall_accuracy-fi_overall_accuracy)/(1-fi_overall_accuracy)\n", + " overall_improvement_color = \"lessgreen\" if overall_improvement>0 else \"red\"\n", + "\n", + " # Create table lines\n", + " lines[0] += f\" & {fi_solvent_accuracy*100:.0f} // {solvent_accuracy*100:.0f} // \\\\textcolor{{{solvent_improvement_color}}}{{{solvent_improvement*100:.0f}\\%}} \"\n", + " lines[1] += f\" & {fi_agent_accuracy*100:.0f} // {agent_accuracy*100:.0f} // \\\\textcolor{{{agent_improvement_color}}}{{{agent_improvement*100:.0f}\\%}} \"\n", + " lines[2] += f\" & {fi_overall_accuracy*100:.0f} // {overall_accuracy*100:.0f} // \\\\textcolor{{{overall_improvement_color}}}{{{overall_improvement*100:.0f}\\%}} \"\n", + "print(\"\\\\\\\\ \\n\".join(lines) + \"\\\\\\\\\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "DATASETS = [\"with_trust_with_map\",\"with_trust_no_map\", \"no_trust_no_map\", \"no_trust_with_map\"]\n", + "LABELS = {\n", + " \"with_trust_with_map\": r\"Labelling, rare $\\rightarrow$ other\",\n", + " \"with_trust_no_map\": r\"Labelling, rare $\\rightarrow$ delete rxn\",\n", + " \"no_trust_no_map\": r\"Reaction string, rare $\\rightarrow$ other\",\n", + " \"no_trust_with_map\": r\"Reaction string, rare $\\rightarrow$ delete rxn\",\n", + "}\n", + "TRAIN_FRACS = [0.2, 0.4, 0.6, 0.8, 1.0]\n", + "fig, ax = plt.subplots(1)\n", + "markers = [\"o\", \"d\", \"s\", \"^\"]\n", + "top_n = 3\n", + "for i, dataset in enumerate(DATASETS):\n", + " overall_accuracies = []\n", + " for train_fraction in TRAIN_FRACS:\n", + " filters = {\n", + " \"state\": \"finished\",\n", + " \"config.output_folder_path\": f\"models/{dataset}\",\n", + " # \"config.random_seed\": {\"$in\": [12345, 54321, 98765]},\n", + " \"config.random_seed\": 54321,\n", + " \"config.train_fraction\": train_fraction,\n", + " \"config.train_mode\": 0, # Teacher forcing\n", + " }\n", + " runs = api.runs(\n", + " f\"{wandb_entity}/{wandb_project}\",\n", + " filters=filters\n", + " )\n", + " assert len(runs) == 1\n", + " run = runs[0]\n", + "\n", + " # Get overall accuracy\n", + " acc_local = []\n", + " for run in runs:\n", + " overall_accuracy = run.summary[f\"test_best\"][f\"overall_accuracy_top{top_n}\"]\n", + " fi_overall_accuracy = run.summary[f\"frequency_informed_overall_accuracy_top_{top_n}\"]\n", + " overall_improvement = (overall_accuracy-fi_overall_accuracy)/(1-fi_overall_accuracy)\n", + " acc_local.append(overall_improvement)\n", + " overall_accuracies.append(np.mean(overall_improvement)*100)\n", + " \n", + " # Add line to plot\n", + " label = LABELS[dataset]\n", + " ax.plot(\n", + " TRAIN_FRACS, \n", + " overall_accuracies, \n", + " label=label, \n", + " linewidth=3.5, \n", + " marker=markers[i], \n", + " 
markersize=10,\n", + " )\n", + "\n", + "# Formatting\n", + "axis_fontsize = 16\n", + "heading_fontsize = 18\n", + "ax.legend(loc=\"upper left\", fontsize=axis_fontsize)\n", + "ax.set_xlabel(\"Training set fraction\", fontsize=heading_fontsize)\n", + "ax.set_ylabel(\"Overall Improvement\", fontsize=heading_fontsize)\n", + "ax.set_xticks(TRAIN_FRACS)\n", + "ax.set_xticklabels(TRAIN_FRACS, fontsize=axis_fontsize)\n", + "# ylabels = np.arange(0.1, 0.35, 0.05)\n", + "# ax.set_yticks(ylabels)\n", + "# ax.set_yticklabels([f\"{ylabel:0.2f}\" for ylabel in ylabels], fontsize=axis_fontsize)\n", + "fig.tight_layout()\n", + "fig.savefig(\"scaling_behavior.png\", dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(runs)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_fraction" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'with_trust_no_map'" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 
+} diff --git a/notebooks/update_wandb_runs.py b/notebooks/update_wandb_runs.py new file mode 100644 index 00000000..44ccf201 --- /dev/null +++ b/notebooks/update_wandb_runs.py @@ -0,0 +1,83 @@ +from condition_prediction.run import ConditionPrediction +import wandb +from tqdm import tqdm +from datetime import datetime +import matplotlib.pyplot as plt +import numpy as np +import pathlib +from tqdm import tqdm +import gc + +api = wandb.Api() +wandb_entity = "ceb-sre" +wandb_project = "orderly" + +# Loop through all relevant runs on wandb to get run_ids, datasets and random seeds +# For each rerun the conditionprediction with skip_training=True and resume=True +DATASETS = [ + "with_trust_with_map", + "with_trust_no_map", + "no_trust_no_map", + "no_trust_with_map", +] +BASE_PATH = pathlib.Path("/project/studios/orderly-preprocessing/ORDerly/") +DATASETS_PATH = BASE_PATH / "data/orderly/datasets/" +MODEL_PATH = pathlib.Path("models/") +configs = [] +# for random_seed in [98765]: +for dataset in DATASETS: + filters = { + "state": "finished", + "config.output_folder_path": { + "$in": [ + f"models/{dataset}", + str(MODEL_PATH / dataset), + f"/Users/Kobi/Documents/Research/phd_code/ORDerly/models/{dataset}", + ], + }, + # "config.random_seed": random_seed, + # "config.train_fraction": 1.0, + # Switching back to v3 for the paper + "config.dataset_version": "v3", + # "config.train_mode": 0, # Teacher forcing + } + runs = api.runs(f"{wandb_entity}/{wandb_project}", filters=filters) + # if not len(runs) == 5: # For 5 training fractions + # raise ValueError(f"Not 5 runs for {dataset} (found {len(runs)}, seed {random_seed})") + + for run in runs: + config = dict(run.config) + train_data_path = pathlib.Path( + f"{DATASETS_PATH}/orderly_{dataset}_train.parquet" + ) + test_data_path = pathlib.Path(f"{DATASETS_PATH}/orderly_{dataset}_test.parquet") + fp_directory = train_data_path.parent / "fingerprints" + train_fp_path = fp_directory / (train_data_path.stem + ".npy") + test_fp_path 
= fp_directory / (test_data_path.stem + ".npy") + output_folder_path = MODEL_PATH / dataset + output_folder_path.mkdir(parents=True, exist_ok=True) + tags = dataset.split("_") + tags = [f"{tags[0]}_{tags[1]}", f"{tags[2]}_{tags[3]}"] + config.update( + { + "train_data_path": train_data_path, + "test_data_path": test_data_path, + "train_fp_path": train_fp_path, + "test_fp_path": test_fp_path, + "output_folder_path": output_folder_path, + "skip_training": True, + "resume": True, + "resume_from_best": True, + "generate_fingerprints": False, + "wandb_run_id": run.id, + "wandb_tags": tags, + } + ) + configs.append(config) + del config["n_val"] + del config["n_test"] + del config["n_train"] + instance = ConditionPrediction(**config) + instance.run_model_arguments() + wandb.finish() + gc.collect()