-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44 from ECRL/dev
Better implementation of TensorFlow 2.0
- Loading branch information
Showing
22 changed files
with
1,087 additions
and
123 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from ecnet.server import Server | ||
__version__ = '3.3.0' | ||
__version__ = '3.3.1' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
|
@@ -13,13 +13,12 @@ | |
from re import compile, IGNORECASE | ||
|
||
# 3rd party imports | ||
from h5py import File | ||
from numpy import array, string_, zeros | ||
from tensorflow import config, Tensor | ||
from numpy import array | ||
from tensorflow import config | ||
from tensorflow.keras.callbacks import EarlyStopping | ||
from tensorflow.keras.layers import Dense | ||
from tensorflow.keras.losses import MeanSquaredError | ||
from tensorflow.keras.models import Model | ||
from tensorflow.keras.models import Sequential, load_model | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
# ECNet imports | ||
|
@@ -45,7 +44,7 @@ def check_h5(filename: str): | |
) | ||
|
||
|
||
class MultilayerPerceptron(Model): | ||
class MultilayerPerceptron: | ||
|
||
def __init__(self, filename: str = 'model.h5'): | ||
''' MultilayerPerceptron: Feed-forward neural network; variable number | ||
|
@@ -56,10 +55,9 @@ def __init__(self, filename: str = 'model.h5'): | |
filename (str): filename/path for the model (default: `model.h5`) | ||
''' | ||
|
||
super(MultilayerPerceptron, self).__init__() | ||
check_h5(filename) | ||
self._filename = filename | ||
self._layers = [] | ||
self._model = Sequential() | ||
|
||
def add_layer(self, num_neurons: int, activation: str, | ||
input_dim: int = None): | ||
|
@@ -76,30 +74,16 @@ def add_layer(self, num_neurons: int, activation: str, | |
neurons), should be kept as `None` (default value) | ||
''' | ||
|
||
if len(self._layers) == 0: | ||
if len(self._model.layers) == 0: | ||
if input_dim is None: | ||
raise ValueError('First layer must have input_dim specified') | ||
|
||
self._layers.append(Dense( | ||
self._model.add(Dense( | ||
units=num_neurons, | ||
activation=activation, | ||
input_shape=(input_dim,) | ||
input_dim=input_dim | ||
)) | ||
|
||
def call(self, x: Tensor) -> Tensor: | ||
''' call: used by Model.fit (parent) to perform feed-forward operations | ||
Args: | ||
x (tf.Tensor): data fed into first layer | ||
Returns: | ||
tf.Tensor: data resulting from last layer | ||
''' | ||
|
||
for layer in self._layers: | ||
x = layer(x) | ||
return x | ||
|
||
def fit(self, l_x: array, l_y: array, v_x: array = None, v_y: array = None, | ||
epochs: int = 1500, lr: float = 0.001, beta_1: float = 0.9, | ||
beta_2: float = 0.999, epsilon: float = 0.0000001, | ||
|
@@ -134,25 +118,25 @@ def fit(self, l_x: array, l_y: array, v_x: array = None, v_y: array = None, | |
epoch | ||
''' | ||
|
||
self.compile(optimizer=Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, | ||
epsilon=epsilon, | ||
decay=decay), | ||
loss=MeanSquaredError()) | ||
self._model.compile(optimizer=Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, | ||
epsilon=epsilon, | ||
decay=decay), | ||
loss=MeanSquaredError()) | ||
|
||
if v_x is not None and v_y is not None: | ||
|
||
callback = EarlyStopping(monitor='val_loss', patience=patience, | ||
restore_best_weights=True) | ||
history = super().fit(l_x, l_y, batch_size=batch_size, | ||
epochs=epochs, verbose=v, | ||
callbacks=[callback], | ||
validation_data=(v_x, v_y)) | ||
history = self._model.fit(l_x, l_y, batch_size=batch_size, | ||
epochs=epochs, verbose=v, | ||
callbacks=[callback], | ||
validation_data=(v_x, v_y)) | ||
return (history.history['loss'], history.history['val_loss']) | ||
|
||
else: | ||
|
||
history = super().fit(l_x, l_y, batch_size=batch_size, | ||
epochs=epochs, verbose=v) | ||
history = self._model.fit(l_x, l_y, batch_size=batch_size, | ||
epochs=epochs, verbose=v) | ||
return (history.history['loss'], [None for _ in range(epochs)]) | ||
|
||
def use(self, x: array) -> array: | ||
|
@@ -165,7 +149,7 @@ def use(self, x: array) -> array: | |
np.array: predicted values | ||
''' | ||
|
||
return self.predict(x) | ||
return self._model.predict(x) | ||
|
||
def save(self, filename: str = None): | ||
''' save: saves the model weights, architecture to either the filename/ | ||
|
@@ -178,15 +162,7 @@ def save(self, filename: str = None): | |
if filename is None: | ||
filename = self._filename | ||
check_h5(filename) | ||
self.save_weights(filename, save_format='h5') | ||
input_size = self.layers[0].get_config()['batch_input_shape'][1] | ||
layer_sizes = [l.get_config()['units'] for l in self.layers] | ||
layer_activ = [l.get_config()['activation'] for l in self.layers] | ||
with File(filename, 'a') as hf: | ||
hf['mlp_input_size'] = input_size | ||
hf['mlp_layer_sizes'] = layer_sizes | ||
hf['mlp_layer_activ'] = string_(layer_activ) | ||
hf.close() | ||
self._model.save(filename, include_optimizer=False) | ||
logger.log('debug', 'Model saved to {}'.format(filename), | ||
call_loc='MLP') | ||
|
||
|
@@ -201,16 +177,6 @@ def load(self, filename: str = None): | |
|
||
if filename is None: | ||
filename = self._filename | ||
with File(filename, 'r') as hf: | ||
input_size = hf.get('mlp_input_size').value | ||
layer_sizes = hf.get('mlp_layer_sizes').value | ||
layer_activ = hf.get('mlp_layer_activ').value | ||
hf.close() | ||
self.add_layer(layer_sizes[0], layer_activ[0].decode('ascii'), | ||
input_size) | ||
for idx, layer in enumerate(layer_sizes[1:]): | ||
self.add_layer(layer, layer_activ[idx].decode('ascii')) | ||
self.build(input_shape=(None, input_size)) | ||
self.load_weights(filename) | ||
self._model = load_model(filename, compile=False) | ||
logger.log('debug', 'Model loaded from {}'.format(filename), | ||
call_loc='MLP') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/training.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for project training (multiprocessed training) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/tuning.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/fitness functions for tuning hyperparameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/database.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for creating ECNet-formatted databases | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/plotting.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for creating various plots | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,22 +2,92 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/project.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for predicting data using pre-existing .prj files | ||
# | ||
|
||
# Stdlib imports | ||
from datetime import datetime | ||
from os import remove | ||
from shutil import rmtree | ||
from os import walk | ||
from os.path import basename, join | ||
from re import compile, IGNORECASE | ||
from tempfile import TemporaryDirectory | ||
from warnings import warn | ||
from zipfile import ZipFile | ||
|
||
# 3rd party imports | ||
from alvadescpy import smiles_to_descriptors | ||
from numpy import asarray, mean | ||
from padelpy import from_smiles | ||
|
||
# ECNet imports | ||
from ecnet import Server | ||
from ecnet.utils.data_utils import DataFrame | ||
from ecnet.utils.logging import logger | ||
from ecnet.tools.database import create_db | ||
from ecnet.models.mlp import MultilayerPerceptron | ||
from ecnet.utils.server_utils import open_config, open_df | ||
|
||
CONFIG_RE = compile(r'^.*\.yml$', IGNORECASE) | ||
MODEL_RE = compile(r'^.*\.h5$', IGNORECASE) | ||
|
||
|
||
class TrainedProject:
    ''' TrainedProject: holds the DataFrame, configuration and trained models
    read from an ECNet .prj archive, and averages the models' predictions for
    new molecules
    '''

    def __init__(self, filename: str):
        ''' TrainedProject: loads a trained ECNet project, including last-used
        DataFrame, configuration .yml file, and all trained models

        Args:
            filename (str): name/path of the trained .prj file
        '''

        self._df = None
        self._config = None
        self._models = []

        with ZipFile(filename, 'r') as zf:
            # Everything is read while the temporary directory still exists;
            # it is deleted as soon as the inner `with` exits
            with TemporaryDirectory() as tmpdirname:
                zf.extractall(tmpdirname)
                # The archive is expected to hold one top-level folder named
                # after the project file with '.prj' removed
                prj_dirname = join(tmpdirname, basename(
                    filename.replace('.prj', '')
                ))
                self._df = open_df(join(prj_dirname, 'data.d'))
                for root, _, files in walk(prj_dirname):
                    for f in files:
                        if MODEL_RE.match(f) is not None:
                            _model = MultilayerPerceptron(join(root, f))
                            _model.load()
                            self._models.append(_model)
                        elif CONFIG_RE.match(f) is not None:
                            # If several .yml files exist, the last one
                            # visited is kept
                            self._config = open_config(join(root, f))

    def use(self, smiles: list, backend: str = 'padel'):
        ''' use: uses the trained project to predict values for supplied
        molecules

        Args:
            smiles (list): list of SMILES strings to predict for
            backend (str): backend software to use for QSPR generation; `padel`
                or `alvadesc`; default = `padel`; alvadesc requires valid
                license

        Returns:
            numpy.array: predicted values, averaged across all loaded models

        Raises:
            ValueError: if `backend` is neither `padel` nor `alvadesc`
        '''

        if backend == 'alvadesc':
            mols = [smiles_to_descriptors(s) for s in smiles]
            # Descriptors reported as 'na' are replaced with 0 so they can be
            # converted to floats below
            for mol in mols:
                for key, value in mol.items():
                    if value == 'na':
                        mol[key] = 0
        elif backend == 'padel':
            mols = [from_smiles(s) for s in smiles]
        else:
            raise ValueError('Unknown backend software: {}'.format(backend))

        # The input matrix does not depend on the model, so build it once and
        # feed the same array to every model
        inputs = asarray(
            [[float(mol[name]) for name in self._df._input_names]
             for mol in mols]
        )
        return mean([model.use(inputs) for model in self._models], axis=0)
|
||
|
||
def predict(smiles: list, prj_file: str, results_file: str = None, | ||
|
@@ -29,23 +99,19 @@ def predict(smiles: list, prj_file: str, results_file: str = None, | |
smiles (str): SMILES strings for molecules | ||
prj_file (str): path to ECNet .prj file | ||
results_file (str): if not none, saves results to this CSV file | ||
(WARNING: deprecated, no longer saves to file) | ||
backend (str): `padel` (default) or `alvadesc`, depending on the data | ||
your project was trained with | ||
Returns: | ||
list: predicted values | ||
''' | ||
|
||
sv = Server(prj_file=prj_file) | ||
|
||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')[:-3] | ||
create_db(smiles, '{}.csv'.format(timestamp), backend=backend) | ||
new_data = DataFrame('{}.csv'.format(timestamp)) | ||
new_data.set_inputs(sv._df._input_names) | ||
new_data.create_sets() | ||
sv._df = new_data | ||
sv._sets = sv._df.package_sets() | ||
results = sv.use(output_filename=results_file) | ||
remove('{}.csv'.format(timestamp)) | ||
rmtree(prj_file.replace('.prj', '')) | ||
return results | ||
if results_file is not None: | ||
class NotImplementedWarning(UserWarning): | ||
pass | ||
warn('`predict` no longer saves directly to a file, results are only' | ||
' returned to the user', NotImplementedWarning) | ||
|
||
project = TrainedProject(prj_file) | ||
return project.use(smiles, backend) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/data_utils.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for loading data, saving data, saving results | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/error_utils.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for error calculations | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/logging.py | ||
# v.3.3.0 | ||
# v.3.3.1 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains logger used by ECNet | ||
|
Oops, something went wrong.