Skip to content

Commit

Permalink
ENH Switch to ONNX for model storage
Browse files Browse the repository at this point in the history
This will avoid warnings related to version mismatches (e.g., #59)

Retrained models to fix off by one error in feature indexing too

This changes results, but probabilities are 0.89 correlated (Spearman)
with previous version

Alas, macrel now requires a recent(ish) version of onnxruntime which
means that Python 3.6 and 3.7 are no longer tested. OTOH, we now test on
3.12 as well
  • Loading branch information
luispedro committed Sep 23, 2024
1 parent 83fe40b commit 3362085
Show file tree
Hide file tree
Showing 22 changed files with 968 additions and 949 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version:
- "3.6"
- "3.7"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"

steps:
- name: Checking code
Expand All @@ -30,8 +29,7 @@ jobs:
shell: bash -l {0}
run : |
conda install -c bioconda -c conda-forge \
ngless pyrodigal megahit paladin pandas requests atomium tzlocal \
"scikit-learn<1.3.0" "joblib<1.3.0"
ngless pyrodigal megahit paladin pandas requests atomium tzlocal onnxruntime
conda install pytest
pip install .
- name: Test with pytest
Expand Down
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Unreleased
* Switch to ONNX for model storage

Version 1.5.0 2024-09-20
* Add support for local searching
* Slightly change output format for AMPSphere matching
Expand Down
24 changes: 15 additions & 9 deletions macrel/AMP_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pandas as pd
import numpy as np
import pickle
import gzip
import logging

Expand All @@ -20,27 +19,34 @@ def predict(model1, model2, data, keep_negatives=False):
-------
Table with prediction labels
'''
model1 = pickle.load(gzip.open(model1, 'rb'))
model2 = pickle.load(gzip.open(model2, 'rb'))
import onnxruntime as rt
with gzip.open(model1, 'rb') as f:
model1 = rt.InferenceSession(f.read(), providers=["CPUExecutionProvider"])

with gzip.open(model2, 'rb') as f:
model2 = rt.InferenceSession(f.read(), providers=["CPUExecutionProvider"])

# limit should be 100, but let's give the user 10% leeway before we warn them
if data.sequence.map(len).max() >= 110:
logger = logging.getLogger('macrel')
logger.warning('Warning: some input sequences are longer than 100 amino-acids.'
' Macrel models were developed and tested for short peptides (<100 amino-acids).'
' Applying them on longer ones will return a result, but these should be considered less reliable.')
features = data.iloc[:, 3:]
features = data.iloc[:, 2:]

# predict_proba will raise an Exception if passed empty arguments
if len(features):
amp_prob = model1.predict_proba(features).T[0]
hemo_prob = model2.predict_proba(features).T[0]
[amp_prob] = model1.run(['output_probability'], {'input_features': features.values.astype(np.float32)})
[hemo_prob] = model2.run(['output_probability'], {'input_features': features.values.astype(np.float32)})

else:
amp_prob = np.array([])
hemo_prob = np.array([])
is_amp = np.where(amp_prob > .5, model1.classes_[0], model1.classes_[1])
is_amp = (is_amp == "AMP")
is_hemo = np.where(hemo_prob > .5, model2.classes_[0], model2.classes_[1])
amp_prob = np.array([x['AMP'] for x in amp_prob])
is_amp = (amp_prob > .5)

hemo_prob = np.array([x['Hemo'] for x in hemo_prob])
is_hemo = np.where(hemo_prob > .5, 'Hemo', 'NonHemo')

final = pd.DataFrame({'Sequence': data['sequence'],
'AMP_family':
Expand Down
Binary file added macrel/data/models/AMP.onnx.gz
Binary file not shown.
Binary file removed macrel/data/models/AMP.pkl.gz
Binary file not shown.
Binary file added macrel/data/models/Hemo.onnx.gz
Binary file not shown.
Binary file removed macrel/data/models/Hemo.pkl.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion macrel/macrel_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.0'
__version__ = '1.6.0.dev0'
4 changes: 2 additions & 2 deletions macrel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def do_predict(args, tdir):
import gzip
fs = fasta_features(args.fasta_file)
prediction = predict(
data_file("models/AMP.pkl.gz"),
data_file("models/Hemo.pkl.gz"),
data_file("models/AMP.onnx.gz"),
data_file("models/Hemo.onnx.gz"),
fs,
args.keep_negatives)
ofile = path.join(args.output, args.outtag + '.prediction.gz')
Expand Down
15 changes: 8 additions & 7 deletions macrel/tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@

def test_predict():
fs = AMP_features.fasta_features('tests/peptides/expep.faa.gz')
fsp = AMP_predict.predict( data_file("models/AMP.pkl.gz"),
data_file("models/Hemo.pkl.gz"),
fsp = AMP_predict.predict(data_file("models/AMP.onnx.gz"),
data_file("models/Hemo.onnx.gz"),
fs)
fsn = AMP_predict.predict( data_file("models/AMP.pkl.gz"),
data_file("models/Hemo.pkl.gz"),
fsn = AMP_predict.predict( data_file("models/AMP.onnx.gz"),
data_file("models/Hemo.onnx.gz"),
fs, keep_negatives=True)
assert len(fsp) < len(fsn)
assert not np.all(fsn.is_AMP)

def test_predict_very_short():
'''Test the prediction of very short sequences (used to crash)'''
fs = AMP_features.fasta_features(
path.join(path.dirname(__file__),
'data',
'very_short.faa'))
assert len(fs) == 2
fsn = AMP_predict.predict(data_file("models/AMP.pkl.gz"),
data_file("models/Hemo.pkl.gz"),
fsn = AMP_predict.predict(data_file("models/AMP.onnx.gz"),
data_file("models/Hemo.onnx.gz"),
fs, keep_negatives=True)
assert not np.any(fsn.is_AMP)
assert np.any(fsn.is_AMP)
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'macrel': 'macrel/',
}
package_data = {
'macrel': ['data/*', 'data/scripts/*.ngl', 'data/models/*.pkl.gz'],
'macrel': ['data/*', 'data/scripts/*.ngl', 'data/models/*.onnx.gz'],
}

packages = setuptools.find_packages()
Expand All @@ -46,12 +46,11 @@
'Intended Audience :: Science/Research',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Operating System :: OS Independent',
'License :: OSI Approved :: MIT License',
]
Expand All @@ -71,7 +70,7 @@
package_data = package_data,
zip_safe = False, # We want the model files to be installed as files
install_requires=[
'scikit-learn',
'onnxruntime',
'pandas',
'requests',
'atomicwrites'
Expand Down
8 changes: 5 additions & 3 deletions tests/contigs.cluster/expected.prediction
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Prediction from macrel v1.5.0
# Prediction from macrel v1.6.0.dev0
Access Sequence AMP_family AMP_probability Hemolytic Hemolytic_probability
smORF_2 RFLIKMVKVNLMNGKLIRKISLM CLP 0.634 Hemo 0.871
smORF_19 FFNDGKGTIYYGIKKYFRIYF CLP 0.673 Hemo 0.822
smORF_0 KVIKKVVAALMVLGALAALTVGVVLKPGRKGDET CLP 0.554 Hemo 0.772
smORF_2 RFLIKMVKVNLMNGKLIRKISLM CLP 0.733 Hemo 0.921
smORF_19 FFNDGKGTIYYGIKKYFRIYF CLP 0.634 Hemo 0.762
smORF_22 TIVVKKVPKCLRGIVKLLFGIKEKWYEKRGYSYSLYFLFVYLL CDP 0.505 Hemo 0.871
2 changes: 1 addition & 1 deletion tests/contigs.nosmorfs/expected.percontigs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Prediction from macrel v1.5.0
# Prediction from macrel v1.6.0.dev0
# Macrel calculated for the sample a density of 0.000 AMPs / Mbp.
contig length ORFs smORFs AMPs
scaffold2530_2_MH0058 1324 1 0 0
2 changes: 1 addition & 1 deletion tests/contigs.nosmorfs/expected.prediction
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Prediction from macrel v1.5.0
# Prediction from macrel v1.6.0.dev0
Access Sequence AMP_family AMP_probability Hemolytic Hemolytic_probability
8 changes: 4 additions & 4 deletions tests/contigs/expected.percontigs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Prediction from macrel v1.5.0
# Macrel calculated for the sample a density of 45.062 AMPs / Mbp.
# Prediction from macrel v1.6.0.dev0
# Macrel calculated for the sample a density of 90.125 AMPs / Mbp.
contig length ORFs smORFs AMPs
C4060843_1_MH0058 518 1 1 0
C4067509_1_MH0058 534 1 1 0
Expand All @@ -11,7 +11,7 @@ scaffold107406_2_MH0058 1401 2 1 0
scaffold16564_1_MH0058 992 2 2 0
scaffold20234_2_MH0058 2926 4 1 0
scaffold24504_2_MH0058 505 1 1 0
scaffold2530_2_MH0058 717 2 2 0
scaffold2530_2_MH0058 717 2 2 1
scaffold30291_4_MH0058 824 2 1 0
scaffold33693_17_MH0058 3481 2 1 1
scaffold34596_7_MH0058 1345 1 0 0
Expand All @@ -20,7 +20,7 @@ scaffold7019_2_MH0058 5218 6 3 0
scaffold75223_9_MH0058 3597 2 2 0
scaffold75334_1_MH0058 3424 1 1 1
scaffold76045_5_MH0058 960 3 3 0
scaffold77554_3_MH0058 6086 9 4 0
scaffold77554_3_MH0058 6086 9 4 1
scaffold8449_1_MH0058 1037 2 1 0
scaffold90770_1_MH0058 1031 2 2 0
scaffold95393_2_MH0058 995 2 1 0
Expand Down
8 changes: 5 additions & 3 deletions tests/contigs/expected.prediction
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Prediction from macrel v1.5.0
# Prediction from macrel v1.6.0.dev0
Access Sequence AMP_family AMP_probability Hemolytic Hemolytic_probability
scaffold75334_1_MH0058_1 RFLIKMVKVNLMNGKLIRKISLM CLP 0.634 Hemo 0.871
scaffold33693_17_MH0058_2 FFNDGKGTIYYGIKKYFRIYF CLP 0.673 Hemo 0.822
scaffold2530_2_MH0058_1 KVIKKVVAALMVLGALAALTVGVVLKPGRKGDET CLP 0.554 Hemo 0.772
scaffold75334_1_MH0058_1 RFLIKMVKVNLMNGKLIRKISLM CLP 0.733 Hemo 0.921
scaffold33693_17_MH0058_2 FFNDGKGTIYYGIKKYFRIYF CLP 0.634 Hemo 0.762
scaffold77554_3_MH0058_7 TIVVKKVPKCLRGIVKLLFGIKEKWYEKRGYSYSLYFLFVYLL CDP 0.505 Hemo 0.871
Loading

0 comments on commit 3362085

Please sign in to comment.