diff --git a/ersilia/cli/commands/serve.py b/ersilia/cli/commands/serve.py index af8953ba8..7a529edaa 100644 --- a/ersilia/cli/commands/serve.py +++ b/ersilia/cli/commands/serve.py @@ -7,7 +7,7 @@ from ..messages import ModelNotFound, ModelNotInStore from ...core.tracking import write_persistent_file from ...store.api import InferenceStoreApi -from ...store.utils import OutputSource +from ...store.utils import OutputSource, store_has_model def serve_cmd(): """Creates serve command""" @@ -22,8 +22,6 @@ def serve_cmd(): required=False, help=f"Get outputs from locally hosted model only ({OutputSource.LOCAL_ONLY}), \ from cloud precalculation store only ({OutputSource.CLOUD_ONLY})" - # or from cloud precalculation store first then locally hosted model for any inputs \ - # that haven't been precalculated ({OutputSource.CLOUD_FIRST})" ) @click.option("--lake/--no-lake", is_flag=True, default=True) @click.option("--docker/--no-docker", is_flag=True, default=False) @@ -50,17 +48,10 @@ def serve(model, output_source, lake, docker, port, track): else: service_class = None if OutputSource.is_cloud(output_source): - store = InferenceStoreApi(model_id=model) - if not store.has_model(): - # if output_source == OutputSource.CLOUD_ONLY: - ModelNotInStore(model).echo() - # echo( - # "Model {0} not found in inference store. 
Serving model with output-source={1}.".format(model, OutputSource.LOCAL_ONLY), - # fg="yellow" - # ) - # output_source = OutputSource.LOCAL_ONLY - else: + if store_has_model(model_id=model): echo("Model {0} found in inference store.".format(model)) + else: + ModelNotInStore(model).echo() mdl = ErsiliaModel( model, output_source=output_source, save_to_lake=lake, diff --git a/ersilia/core/model.py b/ersilia/core/model.py index 2bfd0466c..364296999 100644 --- a/ersilia/core/model.py +++ b/ersilia/core/model.py @@ -27,7 +27,7 @@ from .tracking import RunTracker, create_persistent_file from ..io.readers.file import FileTyper, TabularFileReader from ..store.api import InferenceStoreApi -from ..store.utils import OutputSource +from ..store.utils import OutputSource, store_has_model from ..utils.exceptions_utils.api_exceptions import ApiSpecifiedOutputError from ..default import FETCHED_MODELS_FILENAME, MODEL_SIZE_FILE, CARD_FILE, EOS @@ -368,20 +368,13 @@ def api( self, api_name=None, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE ): if OutputSource.is_cloud(self.output_source): - # here (send to store and get back results dict + missing inputs list) - store = InferenceStoreApi(model_id=self.model_id) - print(self.model_id) - print(input) - if store.has_model(): + if store_has_model(model_id=self.model_id): + store = InferenceStoreApi(model_id=self.model_id) result_from_store = store.get_precalculations(input) - print(result_from_store) - - # if self.output_source == OutputSource.CLOUD_ONLY: - return "this is the result returned to CLI" # should missing keys be returned too in a file/message? - - # if self.output_source == OutputSource.CLOUD_FIRST and len(missing_keys): - # pass # save missing keys to file to use in subsequent steps below - if self._do_cache_splits(input=input, output=output): + else: + result_from_store = "No precalculations found in store." 
+ return result_from_store + elif self._do_cache_splits(input=input, output=output): splitted_inputs = self.tfr.split_in_cache() self.logger.debug("Split inputs:") self.logger.debug(" ".join(splitted_inputs)) diff --git a/ersilia/default.py b/ersilia/default.py index 8ba1198bb..483777701 100644 --- a/ersilia/default.py +++ b/ersilia/default.py @@ -93,6 +93,7 @@ AIRTABLE_MODEL_HUB_VIEW_URL = "https://airtable.com/shrNc3sTtTA3QeEZu" S3_BUCKET_URL = "https://ersilia-models.s3.eu-central-1.amazonaws.com" S3_BUCKET_URL_ZIP = "https://ersilia-models-zipped.s3.eu-central-1.amazonaws.com" +INFERENCE_STORE_API_URL = "https://5x2fkcjtei.execute-api.eu-central-1.amazonaws.com/dev/precalculations" # EOS conda _resolve_script = "conda_env_resolve.py" diff --git a/ersilia/serve/standard_api.py b/ersilia/serve/standard_api.py index 3d914d15f..f60b2d042 100644 --- a/ersilia/serve/standard_api.py +++ b/ersilia/serve/standard_api.py @@ -6,7 +6,7 @@ from .. import ErsiliaBase from ..store.api import InferenceStoreApi -from ..store.utils import OutputSource +from ..store.utils import OutputSource, store_has_model from ..default import ( EXAMPLE_STANDARD_INPUT_CSV_FILENAME, EXAMPLE_STANDARD_OUTPUT_CSV_FILENAME, @@ -226,24 +226,13 @@ def serialize_to_csv(self, input_data, result, output_data): def post(self, input, output, output_source=OutputSource.LOCAL_ONLY): input_data = self.serialize_to_json(input) if OutputSource.is_cloud(output_source): - # - # HERE (send to store and get back results dict + missing inputs list) - # - store = InferenceStoreApi(model_id=self.model_id) - # print(self.model_id) - # print(input_data) - if store.has_model(): + if store_has_model(model_id=self.model_id): + store = InferenceStoreApi(model_id=self.model_id) result_from_store = store.get_precalculations(input_data) - # print(result_from_store) - # print(missing_keys) - - # if output_source == OutputSource.CLOUD_ONLY: - output_data = self.serialize_to_csv(input_data, result_from_store, output) + 
output_data = self.serialize_to_csv(input_data, result_from_store, output) + else: + output_data = "No precalculations found in store." return output_data # should missing keys be returned too in a file/message? - - # if output_source == OutputSource.CLOUD_FIRST and len(missing_keys): - # input_data = missing_keys - url = "{0}/{1}".format(self.url, self.api_name) response = requests.post(url, json=input_data) if response.status_code == 200: diff --git a/ersilia/store/api.py b/ersilia/store/api.py index c500038aa..ac9008d65 100644 --- a/ersilia/store/api.py +++ b/ersilia/store/api.py @@ -1,81 +1,80 @@ from ersilia.core.base import ErsiliaBase from ersilia.io.input import GenericInputAdapter -from ersilia.store.utils import InferenceStoreApiPayload -import requests, uuid +from ersilia.store.utils import InferenceStoreApiPayload, store_has_model +from ersilia.default import INFERENCE_STORE_API_URL +import requests +import tempfile +import uuid -INFERENCE_STORE_UPLOAD_API_URL = "" class InferenceStoreApi(ErsiliaBase): def __init__(self, model_id): ErsiliaBase.__init__(self) self.model_id = model_id - self._upload_url = None + self.request_id = None self.input_adapter = GenericInputAdapter(model_id=self.model_id) - - ### Call Lambda function to get upload URL ### - def _get_presigned_url(self) -> str: - presigned_url = requests.get( - INFERENCE_STORE_API_URL, params={"model_id": self.model_id} - ) - # try: - # upload_url = requests.get( - # INFERENCE_STORE_API_URL, params={"model_id": self.model_id} - # ) - # except: - # upload_url = None - # return upload_url + def _generate_request_id(self) -> str: + self.request_id = str(uuid.uuid4()) + return self.request_id - def has_model(self) -> bool: - # GET request to check if model exists - - # if self._upload_url is None: # TODO: maybe another condition here to check if URL has expired - # self._upload_url = self._get_upload_url() - # return self._upload_url is not Non + def _get_presigned_url(self) -> str: + response = 
requests.get( + INFERENCE_STORE_API_URL + "/upload-destination", + params={ + "modelid": self.model_id, + "requestid": self.request_id + }, + timeout=60 + ) + presigned_url_response = response.json() + return presigned_url_response - ### TODO: Send payload to S3 ### - def _post_inputs(self, presigned_url) -> str: - adapted_input_generator = self.input_adapter.adapt_one_by_one(input) + def _post_inputs(self, inputs, presigned_url_response) -> str: + adapted_input_generator = self.input_adapter.adapt_one_by_one(inputs) smiles_list = [val["input"] for val in adapted_input_generator] payload = InferenceStoreApiPayload(model=self.model_id, inputs=smiles_list) - return "" - ### TODO: Get precalculations from S3 ### - def _get_outputs(self, model_id=None, request_id=None) -> dict: - return {} + with tempfile.NamedTemporaryFile() as input_file: + input_file.write(payload.model_dump_json().encode()) + inputs = input_file.name - def get_precalculations(self, input: str): + with open(inputs, 'rb') as f: + files = {'file': (inputs, f)} + presigned_url = presigned_url_response.get('url') + presigned_url_data = presigned_url_response.get('fields') + http_response = requests.post( + presigned_url, + data=presigned_url_data, + files=files, + timeout=60 + ) - def generate_request_id(): - return str(uuid.uuid4()) + return http_response - # print("-----------------------") - # print("GOT MODEL ID") - # print(model_id) - # print("GOT INPUT") - # print(input) - # print("ADAPTED INPUT") - - # print(payload.model_id) - # print(payload.inputs) - # print(payload.model_dump()) - - request_id = generate_request_id() - - presigned_url = self._get_presigned_url() + def _get_outputs(self) -> str: + response = requests.post( + INFERENCE_STORE_API_URL + "/precalculations", + params={ + "requestid": self.request_id + }, + timeout=60 + ) + return response.text - response = self._post_inputs(presigned_url) - if response.status_code == 200: + def get_precalculations(self, inputs): + if 
store_has_model(self.model_id): + self._generate_request_id() + presigned_url = self._get_presigned_url() + + response = self._post_inputs(inputs, presigned_url) + if response.status_code == 204: print('File uploaded successfully') else: - print('Failed to upload file:', response.status_code, response.text) - - precalculations = self._get_outputs(self.model_id, request_id) + return f'Failed to upload file: {response.status_code} error ({response.text})' - return precalculations + precalculations = self._get_outputs() -if __name__ == "__main__": - abc = 'https://bzr6zxw1k0.execute-api.ap-southeast-2.amazonaws.com/4-8-24' - response = requests.get(abc, params={"model_id": 'ey1829ey', "request_id": '12e1-2e-12e12e'}) - print(response.json()) \ No newline at end of file + # TODO: should missing keys be returned too in a file/message? + return precalculations # this is the result returned to the CLI diff --git a/ersilia/store/utils.py b/ersilia/store/utils.py index c0fac70f4..f1c2b36fe 100644 --- a/ersilia/store/utils.py +++ b/ersilia/store/utils.py @@ -1,11 +1,11 @@ +from ersilia.default import INFERENCE_STORE_API_URL +import requests class OutputSource(): LOCAL_ONLY = "local-only" CLOUD_ONLY = "cloud-only" - # CLOUD_FIRST = "cloud-first" ALL = [ LOCAL_ONLY, CLOUD_ONLY, - #CLOUD_FIRST ] @classmethod @@ -14,12 +14,24 @@ def is_local(cls, option): @classmethod def is_cloud(cls, option): - return option in ( - cls.CLOUD_ONLY, - #cls.CLOUD_FIRST - ) + return option == cls.CLOUD_ONLY from pydantic import BaseModel class InferenceStoreApiPayload(BaseModel): model: str - inputs: list[str] = [] # validation error if inputs is None e.g. if inchi->smiles fails \ No newline at end of file + inputs: list[str] = [] # validation error if inputs is None e.g. 
if inchi->smiles fails + + +def store_has_model(model_id: str) -> bool: + response = requests.get( + INFERENCE_STORE_API_URL + "/model", + params={ + "modelid": model_id + }, + timeout=60 + ) + if response.status_code == 200: + print("Model found in inference store: ", response.json()) + return True + print("Model not found in inference store") + return False