-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
122 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,25 +2,27 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
# | ||
|
||
# Stdlib imports | ||
from os import environ | ||
from re import compile, IGNORECASE | ||
|
||
# 3rd party imports | ||
from h5py import File | ||
from numpy import array, string_, zeros | ||
|
||
from tensorflow import config, Tensor | ||
from tensorflow.keras.callbacks import EarlyStopping | ||
from tensorflow.keras.layers import Dense | ||
from tensorflow.keras.losses import MeanSquaredError | ||
from tensorflow.keras.models import Model | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
# ECNet imports | ||
from ecnet.utils.logging import logger | ||
|
||
environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
# network model creation, data hand-off to models, prediction error | ||
|
@@ -26,8 +26,8 @@ | |
|
||
class Server: | ||
|
||
def __init__(self, model_config: str='config.yml', prj_file: str=None, | ||
num_processes: int=1): | ||
def __init__(self, model_config: str = 'config.yml', prj_file: str = None, | ||
num_processes: int = 1): | ||
'''Server object: handles data loading, model creation, data-to-model | ||
hand-off, data input parameter selection, hyperparameter tuning | ||
|
@@ -69,8 +69,8 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None, | |
self._vars = default_config() | ||
save_config(self._vars, self._cf_file) | ||
|
||
def load_data(self, filename: str, random: bool=False, split: list=None, | ||
normalize: bool=False): | ||
def load_data(self, filename: str, random: bool = False, | ||
split: list = None, normalize: bool = False): | ||
'''Loads data from an ECNet-formatted CSV database | ||
Args: | ||
|
@@ -90,8 +90,8 @@ def load_data(self, filename: str, random: bool=False, split: list=None, | |
self._df.create_sets(random, split) | ||
self._sets = self._df.package_sets() | ||
|
||
def create_project(self, project_name: str, num_pools: int=1, | ||
num_candidates: int=1): | ||
def create_project(self, project_name: str, num_pools: int = 1, | ||
num_candidates: int = 1): | ||
'''Creates folder hierarchy for a new project | ||
Args: | ||
|
@@ -111,8 +111,8 @@ def create_project(self, project_name: str, num_pools: int=1, | |
logger.log('debug', 'Number of candidates/pool: {}'.format( | ||
num_candidates), call_loc='PROJECT') | ||
|
||
def limit_inputs(self, limit_num: int, num_estimators: int=None, | ||
eval_set: str='learn', output_filename: str=None, | ||
def limit_inputs(self, limit_num: int, num_estimators: int = None, | ||
eval_set: str = 'learn', output_filename: str = None, | ||
**kwargs) -> list: | ||
'''Selects `limit_num` influential input parameters using random | ||
forest regression | ||
|
@@ -149,9 +149,9 @@ def limit_inputs(self, limit_num: int, num_estimators: int=None, | |
return result | ||
|
||
def tune_hyperparameters(self, num_employers: int, num_iterations: int, | ||
shuffle: bool=None, split: list=None, | ||
validate: bool=True, eval_set: str=None, | ||
eval_fn: str='rmse', epochs: int=300): | ||
shuffle: bool = None, split: list = None, | ||
validate: bool = True, eval_set: str = None, | ||
eval_fn: str = 'rmse', epochs: int = 300): | ||
'''Tunes neural network learning hyperparameters using an artificial | ||
bee colony algorithm; tuned hyperparameters are saved to Server's | ||
model configuration file | ||
|
@@ -185,10 +185,10 @@ def tune_hyperparameters(self, num_employers: int, num_iterations: int, | |
) | ||
save_config(self._vars, self._cf_file) | ||
|
||
def train(self, shuffle: str=None, split: list=None, retrain: bool=False, | ||
validate: bool=False, selection_set: str=None, | ||
selection_fn: str='rmse', model_filename: str='model.h5', | ||
verbose=0) -> list: | ||
def train(self, shuffle: str = None, split: list = None, | ||
retrain: bool = False, validate: bool = False, | ||
selection_set: str = None, selection_fn: str = 'rmse', | ||
model_filename: str = 'model.h5', verbose: int = 0) -> tuple: | ||
'''Trains neural network(s) using currently-loaded data; single NN if | ||
no project is created, all candidates if created | ||
|
@@ -210,8 +210,8 @@ def train(self, shuffle: str=None, split: list=None, retrain: bool=False, | |
model only) | ||
Returns: | ||
list: if training single model, returns list of learn/valid losses, | ||
else None | ||
tuple: if training single model, returns tuple of learn/valid | ||
losses, else None | ||
''' | ||
|
||
if self._prj_name is None: | ||
|
@@ -246,8 +246,8 @@ def train(self, shuffle: str=None, split: list=None, retrain: bool=False, | |
) | ||
return None | ||
|
||
def use(self, dset: str=None, output_filename: str=None, | ||
model_filename: str='model.h5') -> list: | ||
def use(self, dset: str = None, output_filename: str = None, | ||
model_filename: str = 'model.h5') -> list: | ||
'''Uses trained neural network(s) to predict for specified set; single | ||
NN if no project created, best pool candidates if created | ||
|
@@ -277,8 +277,8 @@ def use(self, dset: str=None, output_filename: str=None, | |
call_loc='USE') | ||
return results | ||
|
||
def errors(self, *args, dset: str=None, | ||
model_filename: str='model.h5') -> dict: | ||
def errors(self, *args, dset: str = None, | ||
model_filename: str = 'model.h5') -> dict: | ||
'''Obtains various errors for specified set | ||
Args: | ||
|
@@ -304,8 +304,8 @@ def errors(self, *args, dset: str=None, | |
logger.log('debug', 'Errors: {}'.format(errors), call_loc='ERRORS') | ||
return errors | ||
|
||
def save_project(self, filename: str=None, clean_up: bool=True, | ||
del_candidates: bool=False): | ||
def save_project(self, filename: str = None, clean_up: bool = True, | ||
del_candidates: bool = False): | ||
'''Saves current state of project to a .prj file | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
# | ||
|
@@ -21,8 +21,8 @@ | |
from ecnet.utils.server_utils import get_x, get_y | ||
|
||
|
||
def limit_rforest(df: DataFrame, limit_num: int, num_estimators: int=None, | ||
num_processes: int=1, eval_set: str='learn', | ||
def limit_rforest(df: DataFrame, limit_num: int, num_estimators: int = None, | ||
num_processes: int = 1, eval_set: str = 'learn', | ||
**kwargs) -> list: | ||
'''Uses random forest regression to select input parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/training.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for project training (multiprocessed training) | ||
# | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/tuning.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/fitness functions for tuning hyperparameters | ||
# | ||
|
@@ -22,10 +22,10 @@ | |
|
||
|
||
def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, | ||
num_iterations: int, num_processes: int=1, | ||
shuffle: str=None, split: list=None, | ||
validate: bool=True, eval_set: str=None, | ||
eval_fn: str='rmse', epochs: int=300) -> dict: | ||
num_iterations: int, num_processes: int = 1, | ||
shuffle: str = None, split: list = None, | ||
validate: bool = True, eval_set: str = None, | ||
eval_fn: str = 'rmse', epochs: int = 300) -> dict: | ||
'''Tunes neural network learning/architecture hyperparameters | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/database.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for creating ECNet-formatted databases | ||
# | ||
|
@@ -20,7 +20,7 @@ | |
|
||
try: | ||
import pybel | ||
except: | ||
except ImportError: | ||
pybel = None | ||
|
||
|
||
|
@@ -35,9 +35,9 @@ def __init__(self, id): | |
self.inputs = None | ||
|
||
|
||
def create_db(smiles: list, db_name: str, targets: list=None, | ||
id_prefix: str='', extra_strings: dict={}, backend: str='padel', | ||
convert_mdl: bool=False): | ||
def create_db(smiles: list, db_name: str, targets: list = None, | ||
id_prefix: str = '', extra_strings: dict = {}, | ||
backend: str = 'padel', convert_mdl: bool = False): | ||
''' create_db: creates an ECNet-formatted database from SMILES strings | ||
using either PaDEL-Descriptor or alvaDesc software; using alvaDesc | ||
requires a valid installation/license of alvaDesc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/plotting.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for creating various plots | ||
# | ||
|
@@ -18,9 +18,10 @@ | |
|
||
class ParityPlot: | ||
|
||
def __init__(self, title: str='Parity Plot', | ||
x_label: str='Experimental Value', | ||
y_label: str='Predicted Value', font: str='Times New Roman'): | ||
def __init__(self, title: str = 'Parity Plot', | ||
x_label: str = 'Experimental Value', | ||
y_label: str = 'Predicted Value', | ||
font: str = 'Times New Roman'): | ||
''' ParityPlot: creates a plot of predicted values vs. experimental | ||
data relative to a 1:1 parity line | ||
|
@@ -39,7 +40,7 @@ def __init__(self, title: str='Parity Plot', | |
self._min_val = 0 | ||
self._labels = None | ||
|
||
def add_series(self, x_vals, y_vals, name: str=None, color: str=None): | ||
def add_series(self, x_vals, y_vals, name: str = None, color: str = None): | ||
''' Adds data to the plot | ||
Args: | ||
|
@@ -67,7 +68,7 @@ def add_series(self, x_vals, y_vals, name: str=None, color: str=None): | |
if y_min < self._min_val: | ||
self._min_val = y_min | ||
|
||
def add_error_bars(self, error: float, label: str=None): | ||
def add_error_bars(self, error: float, label: str = None): | ||
''' Adds error bars, +/- the error relative to the 1:1 parity line | ||
Args: | ||
|
@@ -96,7 +97,7 @@ def save(self, filename: str): | |
self._add_parity_line() | ||
plt.savefig(filename) | ||
|
||
def _add_parity_line(self, offset: float=0.0): | ||
def _add_parity_line(self, offset: float = 0.0): | ||
''' Adds a 1:1 parity line | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/project.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for predicting data using pre-existing .prj files | ||
# | ||
|
@@ -20,8 +20,8 @@ | |
from ecnet.tools.database import create_db | ||
|
||
|
||
def predict(smiles: list, prj_file: str, results_file: str=None, | ||
backend: str='padel') -> list: | ||
def predict(smiles: list, prj_file: str, results_file: str = None, | ||
backend: str = 'padel') -> list: | ||
''' predict: predicts values for supplied molecules (SMILES strings) using | ||
pre-existing ECNet project (.prj) file | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/data_utils.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for loading data, saving data, saving results | ||
# | ||
|
@@ -33,8 +33,8 @@ def __init__(self): | |
class PackagedData: | ||
|
||
def __init__(self): | ||
'''PackagedData object: contains lists of input and target data for data | ||
set assignments | ||
'''PackagedData object: contains lists of input and target data for | ||
data set assignments | ||
''' | ||
|
||
self.learn_x = [] | ||
|
@@ -112,7 +112,7 @@ def __len__(self): | |
|
||
return len(self.data_points) | ||
|
||
def create_sets(self, random: bool=False, split: list=[0.7, 0.2, 0.1]): | ||
def create_sets(self, random: bool = False, split: list = [0.7, 0.2, 0.1]): | ||
'''Creates learning, validation and test sets | ||
Args: | ||
|
@@ -167,9 +167,9 @@ def create_sets(self, random: bool=False, split: list=[0.7, 0.2, 0.1]): | |
logger.log('debug', 'Number of entries in test set: {}'.format( | ||
len(self.test_set)), call_loc='DF') | ||
|
||
def create_sorted_sets(self, sort_str: str, split: list=[0.7, 0.2, 0.1]): | ||
'''Creates random learn, validate and test sets, ensuring data points with | ||
the supplied sort string are split proportionally between the sets | ||
def create_sorted_sets(self, sort_str: str, split: list = [0.7, 0.2, 0.1]): | ||
'''Creates random learn, validate and test sets, ensuring data points | ||
with the supplied sort string are split proportionally between the sets | ||
Args: | ||
sort_str (str): database STRING value used to sort data points | ||
|
@@ -239,7 +239,7 @@ def normalize(self): | |
(float(getattr(pt, inp)) - v_min) / (v_max - v_min) | ||
) | ||
|
||
def shuffle(self, sets: str='all', split: list=[0.7, 0.2, 0.1]): | ||
def shuffle(self, sets: str = 'all', split: list = [0.7, 0.2, 0.1]): | ||
'''Shuffles learning, validation and test sets or learning and | ||
validation sets | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/error_utils.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for error calculations | ||
# | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,8 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/logging.py | ||
# v.3.2.3 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# v.3.3.0 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains logger used by ECNet | ||
# | ||
|
Oops, something went wrong.