Skip to content

Commit

Permalink
Merge branch 'ersilia-os:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Abellegese authored Nov 23, 2024
2 parents 098e010 + 98618b7 commit c67d9f3
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
name: Fetch model
command: |
source activate ersilia
ersilia -v fetch eos0t01 --repo_path ./test/models/eos0t01
ersilia -v fetch eos0t01 --from_dir ./test/models/eos0t01
- run:
name: Delete model
command: |
Expand Down
10 changes: 3 additions & 7 deletions ersilia/cli/commands/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def _fetch(mf, model_id):
"an EOS folder, then packed to a BentoML bundle",
)
@click.argument("model", type=click.STRING)
@click.option("--repo_path", "-r", default=None, type=click.STRING)
@click.option("--mode", "-m", default=None, type=click.STRING)
@click.option("--dockerize/--not-dockerize", default=False)
@click.option(
Expand Down Expand Up @@ -78,7 +77,6 @@ def _fetch(mf, model_id):
)
def fetch(
model,
repo_path,
mode,
dockerize,
overwrite,
Expand All @@ -93,11 +91,9 @@ def fetch(
):
if with_bentoml and with_fastapi:
raise Exception("Cannot use both BentoML and FastAPI")
if repo_path is not None:
mdl = ModelBase(repo_path=repo_path)
elif from_dir is not None:

if from_dir is not None:
mdl = ModelBase(repo_path=from_dir)
repo_path = from_dir
else:
mdl = ModelBase(model_id_or_slug=model)
model_id = mdl.model_id
Expand All @@ -106,7 +102,7 @@ def fetch(
fg="blue",
)
mf = ModelFetcher(
repo_path=repo_path,
repo_path=from_dir,
mode=mode,
dockerize=dockerize,
overwrite=overwrite,
Expand Down
4 changes: 3 additions & 1 deletion ersilia/utils/identifiers/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def is_key_header(self, h):
return h.lower() in self.key_header_synonyms

def _is_smiles(self, text):
if not isinstance(text, str) or not text.strip():
return False
if self.Chem is None:
return asyncio.run(self._process_pubchem_inchikey(text)) is not None
else:
Expand Down Expand Up @@ -182,7 +184,7 @@ async def process_smiles(self, smiles, semaphore, session, result_list):
logger.info("Inchikey converted using PUBCHEM")

if inchikey is None:
inchikey = self._nci_smiles_to_inchikey(smiles)
inchikey = self._nci_smiles_to_inchikey(session, smiles)
if inchikey:
logger.info("Inchikey converted using NCI")

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ docker = "^7.1.0"
boto3 = "^1.28.40"
requests = "^2.31.0"
numpy = "<=1.26.4"
setuptools = "^65.0.0" # added to fix the issue with setuptools
setuptools = "^70.0.0" # added to fix the issue with setuptools
isaura = { version = "0.1", optional = true }
aiofiles = "<=24.1.0"
aiohttp = "<=3.10.9"
aiohttp = ">=3.10.11"
nest_asyncio = "<=1.6.0"
pytest = { version = "^7.4.0", optional = true }
pytest-asyncio = { version = "<=0.24.0", optional = true }
Expand Down
100 changes: 81 additions & 19 deletions test/test_compound_identifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@

from ersilia.default import UNPROCESSABLE_INPUT
import pytest
from ersilia.utils.identifiers.compound import CompoundIdentifier
from unittest.mock import patch
Expand Down Expand Up @@ -38,29 +40,89 @@ def test_is_inchikey_positive(compound_identifier, inchikey):
"""Test that valid InChIKeys return True."""
assert compound_identifier._is_inchikey(inchikey) is True

@pytest.fixture(params=[True, False], ids=["Chem_None", "Chem_Not_None"])
def compound_identifier(request):
"""Fixture that initializes CompoundIdentifier with or without RDKit."""
return CompoundIdentifier(local=request.param)

@pytest.mark.parametrize("inchikey", [
"BSYNRYMUTXBXSQUHFFFAOYSA",
"BSYNRYMUTXBXSQ-UHFFFAOYSA-XY",
"12345678901234-1234567890-X",
"BSYNRYMUTXBXSQ_UHFFFAOYSA-N",
"BSYNRYMUTXBXSQ-UHFFFAOYSA"
@pytest.mark.parametrize("smiles, expected", [
("C", True),
("CCO", True)
])
def test_is_inchikey_negative(compound_identifier, inchikey):
"""Test that invalid InChIKeys return False."""
assert not compound_identifier._is_inchikey(inchikey)
def test_is_smiles_positive(compound_identifier, smiles, expected):
"""Test _is_smiles returns True for valid SMILES strings."""
if compound_identifier.Chem is None:
assert compound_identifier._is_smiles(smiles) == expected


def test_guess_type_with_inchikey(compound_identifier):
inchikey = "LFQSCWFLJHTTHZ-UHFFFAOYSA-N"
@pytest.mark.parametrize("smiles, expected", [
("invalid_smiles", False),
("", False)
])
def test_is_smiles_negative(compound_identifier, smiles, expected):
"""Test _is_smiles returns False for invalid or empty SMILES strings."""
assert compound_identifier._is_smiles(smiles) == expected

@pytest.mark.parametrize("inchikey, expected", [
("BQJCRHHNABKAKU-KBQPJGBKSA-N", True),
])
def test_is_inchikey_positive(inchikey, expected):
"""Test _is_inchikey returns True for valid InChIKey."""
assert CompoundIdentifier._is_inchikey(inchikey) == expected

@pytest.mark.parametrize("inchikey, expected", [
("invalid_inchikey", False),
("BQJCRHHNABKAKU-KBQPJGBKSA", False)
])
def test_is_inchikey_negative(inchikey, expected):
"""Test _is_inchikey returns False for invalid InChIKeys."""
assert CompoundIdentifier._is_inchikey(inchikey) == expected

@pytest.mark.parametrize("inchikey, expected", [
("BQJCRHHNABKAKU-KBQPJGBKSA-N", "inchikey"),
("ABCDEFGHIJKLMN-OPQRSTUVWX-Y", "inchikey"),
])
def test_guess_type_inchikey(compound_identifier, inchikey, expected):
"""Ensure guess_type correctly identifies valid InChIKeys."""
result = compound_identifier.guess_type(inchikey)
assert result == "inchikey"


@patch('ersilia.utils.identifiers.compound.CompoundIdentifier._pubchem_smiles_to_inchikey')
def test_is_smiles_positive_chem_none(mock_pubchem, compound_identifier):
compound_identifier.Chem = None
mock_pubchem.return_value = "InChIKey"
assert result == expected, f"Expected 'inchikey', but got '{result}' for input '{inchikey}'"

@pytest.mark.parametrize("smiles, expected", [
("C", "smiles"),
("CCO", "smiles"),
])
def test_guess_type_smiles(compound_identifier, smiles, expected):
"""Ensure guess_type correctly identifies valid SMILES strings."""
result = compound_identifier.guess_type(smiles)
assert result == expected, f"Expected 'smiles', but got '{result}' for input '{smiles}'"

@pytest.mark.parametrize("input_data, expected", [
(None, UNPROCESSABLE_INPUT),
(UNPROCESSABLE_INPUT, UNPROCESSABLE_INPUT),
])
def test_guess_type_unprocessable(compound_identifier, input_data, expected):
"""Ensure guess_type returns UNPROCESSABLE_INPUT for None or unprocessable inputs."""
result = compound_identifier.guess_type(input_data)
assert result == expected, f"Expected '{UNPROCESSABLE_INPUT}', but got '{result}'"

@pytest.mark.parametrize("whitespace_input, expected", [
("\n", UNPROCESSABLE_INPUT),
("\t", UNPROCESSABLE_INPUT),
(" ", UNPROCESSABLE_INPUT),
])
def test_guess_type_whitespace(compound_identifier, whitespace_input, expected):
"""Ensure guess_type returns UNPROCESSABLE_INPUT for whitespace-only input."""
result = compound_identifier.guess_type(whitespace_input)
assert result == expected, f"Expected '{UNPROCESSABLE_INPUT}' for input '{whitespace_input}'"

@pytest.mark.parametrize("non_char_input, expected", [
(12345, UNPROCESSABLE_INPUT),
(3.14, UNPROCESSABLE_INPUT),
("𠜎𠜱𡿺𠬠", UNPROCESSABLE_INPUT),
])
def test_guess_type_non_character(compound_identifier, non_char_input, expected):
"""Ensure guess_type returns UNPROCESSABLE_INPUT for non-character input."""
result = compound_identifier.guess_type(non_char_input)
assert result == expected, f"Expected '{UNPROCESSABLE_INPUT}' for input '{non_char_input}'"

# Test with a valid SMILES input
smiles_string = 'CCO' #Ethanol SMILES
Expand Down

0 comments on commit c67d9f3

Please sign in to comment.