Skip to content

Commit

Permalink
Remove cloud_first and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-kondal committed Aug 24, 2024
1 parent 0c1820c commit 87a586c
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 108 deletions.
17 changes: 4 additions & 13 deletions ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..messages import ModelNotFound, ModelNotInStore
from ...core.tracking import write_persistent_file
from ...store.api import InferenceStoreApi
from ...store.utils import OutputSource
from ...store.utils import OutputSource, store_has_model

def serve_cmd():
"""Creates serve command"""
Expand All @@ -22,8 +22,6 @@ def serve_cmd():
required=False,
help=f"Get outputs from locally hosted model only ({OutputSource.LOCAL_ONLY}), \
from cloud precalculation store only ({OutputSource.CLOUD_ONLY})"
# or from cloud precalculation store first then locally hosted model for any inputs \
# that haven't been precalculated ({OutputSource.CLOUD_FIRST})"
)
@click.option("--lake/--no-lake", is_flag=True, default=True)
@click.option("--docker/--no-docker", is_flag=True, default=False)
Expand All @@ -50,17 +48,10 @@ def serve(model, output_source, lake, docker, port, track):
else:
service_class = None
if OutputSource.is_cloud(output_source):
store = InferenceStoreApi(model_id=model)
if not store.has_model():
# if output_source == OutputSource.CLOUD_ONLY:
ModelNotInStore(model).echo()
# echo(
# "Model {0} not found in inference store. Serving model with output-source={1}.".format(model, OutputSource.LOCAL_ONLY),
# fg="yellow"
# )
# output_source = OutputSource.LOCAL_ONLY
else:
if store_has_model(model_id=model):
echo("Model {0} found in inference store.".format(model))
else:
ModelNotInStore(model).echo()
mdl = ErsiliaModel(
model,
output_source=output_source, save_to_lake=lake,
Expand Down
21 changes: 7 additions & 14 deletions ersilia/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .tracking import RunTracker, create_persistent_file
from ..io.readers.file import FileTyper, TabularFileReader
from ..store.api import InferenceStoreApi
from ..store.utils import OutputSource
from ..store.utils import OutputSource, store_has_model

from ..utils.exceptions_utils.api_exceptions import ApiSpecifiedOutputError
from ..default import FETCHED_MODELS_FILENAME, MODEL_SIZE_FILE, CARD_FILE, EOS
Expand Down Expand Up @@ -368,20 +368,13 @@ def api(
self, api_name=None, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE
):
if OutputSource.is_cloud(self.output_source):
# here (send to store and get back results dict + missing inputs list)
store = InferenceStoreApi(model_id=self.model_id)
print(self.model_id)
print(input)
if store.has_model():
if store_has_model(model_id=self.model_id):
store = InferenceStoreApi(model_id=self.model_id)
result_from_store = store.get_precalculations(input)
print(result_from_store)

# if self.output_source == OutputSource.CLOUD_ONLY:
return "this is the result returned to CLI" # should missing keys be returned too in a file/message?

# if self.output_source == OutputSource.CLOUD_FIRST and len(missing_keys):
# pass # save missing keys to file to use in subsequent steps below
if self._do_cache_splits(input=input, output=output):
else:
result_from_store = "No precalculations found in store."
return result_from_store
elif self._do_cache_splits(input=input, output=output):
splitted_inputs = self.tfr.split_in_cache()
self.logger.debug("Split inputs:")
self.logger.debug(" ".join(splitted_inputs))
Expand Down
1 change: 1 addition & 0 deletions ersilia/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
AIRTABLE_MODEL_HUB_VIEW_URL = "https://airtable.com/shrNc3sTtTA3QeEZu"
S3_BUCKET_URL = "https://ersilia-models.s3.eu-central-1.amazonaws.com"
S3_BUCKET_URL_ZIP = "https://ersilia-models-zipped.s3.eu-central-1.amazonaws.com"
INFERENCE_STORE_API_URL = "https://5x2fkcjtei.execute-api.eu-central-1.amazonaws.com/dev/precalculations"

# EOS conda
_resolve_script = "conda_env_resolve.py"
Expand Down
23 changes: 6 additions & 17 deletions ersilia/serve/standard_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .. import ErsiliaBase
from ..store.api import InferenceStoreApi
from ..store.utils import OutputSource
from ..store.utils import OutputSource, store_has_model
from ..default import (
EXAMPLE_STANDARD_INPUT_CSV_FILENAME,
EXAMPLE_STANDARD_OUTPUT_CSV_FILENAME,
Expand Down Expand Up @@ -226,24 +226,13 @@ def serialize_to_csv(self, input_data, result, output_data):
def post(self, input, output, output_source=OutputSource.LOCAL_ONLY):
input_data = self.serialize_to_json(input)
if OutputSource.is_cloud(output_source):
#
# HERE (send to store and get back results dict + missing inputs list)
#
store = InferenceStoreApi(model_id=self.model_id)
# print(self.model_id)
# print(input_data)
if store.has_model():
if store_has_model(model_id=self.model_id):
store = InferenceStoreApi(model_id=self.model_id)
result_from_store = store.get_precalculations(input_data)
# print(result_from_store)
# print(missing_keys)

# if output_source == OutputSource.CLOUD_ONLY:
output_data = self.serialize_to_csv(input_data, result_from_store, output)
output_data = self.serialize_to_csv(input_data, result_from_store, output)
else:
output_data = "No precalculations found in store."
return output_data # should missing keys be returned too in a file/message?

# if output_source == OutputSource.CLOUD_FIRST and len(missing_keys):
# input_data = missing_keys

url = "{0}/{1}".format(self.url, self.api_name)
response = requests.post(url, json=input_data)
if response.status_code == 200:
Expand Down
113 changes: 56 additions & 57 deletions ersilia/store/api.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,80 @@
from ersilia.core.base import ErsiliaBase
from ersilia.io.input import GenericInputAdapter
from ersilia.store.utils import InferenceStoreApiPayload
import requests, uuid
from ersilia.store.utils import InferenceStoreApiPayload, store_has_model
from ersilia.default import INFERENCE_STORE_API_URL
import requests
import tempfile
import uuid

INFERENCE_STORE_UPLOAD_API_URL = ""

class InferenceStoreApi(ErsiliaBase):

def __init__(self, model_id):
ErsiliaBase.__init__(self)
self.model_id = model_id
self._upload_url = None
self.request_id = None
self.input_adapter = GenericInputAdapter(model_id=self.model_id)

### Call Lambda function to get upload URL ###
def _get_presigned_url(self) -> str:
presigned_url = requests.get(
INFERENCE_STORE_API_URL, params={"model_id": self.model_id}
)

# try:
# upload_url = requests.get(
# INFERENCE_STORE_API_URL, params={"model_id": self.model_id}
# )
# except:
# upload_url = None
# return upload_url
def _generate_request_id(self) -> str:
    """Mint a fresh UUID4 request id, remember it on the instance, and return it."""
    new_id = str(uuid.uuid4())
    self.request_id = new_id
    return new_id

def has_model(self) -> bool:
# GET request to check if model exists

# if self._upload_url is None: # TODO: maybe another condition here to check if URL has expired
# self._upload_url = self._get_upload_url()
# return self._upload_url is not Non
def _get_presigned_url(self) -> dict:
    """Ask the inference store for a presigned S3 upload destination.

    Sends the model id and the current request id to the
    ``/upload-destination`` endpoint and returns the parsed JSON body
    (presumably containing the presigned ``url`` and ``fields`` used by
    ``_post_inputs`` — confirm against the API).

    Note: the original annotation said ``-> str`` but the method returns
    ``response.json()``, i.e. a parsed JSON object; annotation corrected.
    """
    response = requests.get(
        INFERENCE_STORE_API_URL + "/upload-destination",
        params={
            "modelid": self.model_id,
            "requestid": self.request_id
        },
        timeout=60
    )
    # Raises ValueError/JSONDecodeError if the body is not valid JSON.
    presigned_url_response = response.json()
    return presigned_url_response

### TODO: Send payload to S3 ###
def _post_inputs(self, presigned_url) -> str:
adapted_input_generator = self.input_adapter.adapt_one_by_one(input)
def _post_inputs(self, inputs, presigned_url_response) -> str:
adapted_input_generator = self.input_adapter.adapt_one_by_one(inputs)
smiles_list = [val["input"] for val in adapted_input_generator]
payload = InferenceStoreApiPayload(model=self.model_id, inputs=smiles_list)
return ""

### TODO: Get precalculations from S3 ###
def _get_outputs(self, model_id=None, request_id=None) -> dict:
return {}
with tempfile.NamedTemporaryFile() as input_file:
input_file.write(payload.model_dump())
inputs = input_file.name

def get_precalculations(self, input: str):
with open(inputs, 'rb') as f:
files = {'file': (inputs, f)}
presigned_url = presigned_url_response.get('url')
presigned_url_data = presigned_url_response.get('fields')
http_response = requests.post(
presigned_url,
data=presigned_url_data,
files=files,
timeout=60
)

def generate_request_id():
return str(uuid.uuid4())
return http_response

# print("-----------------------")
# print("GOT MODEL ID")
# print(model_id)
# print("GOT INPUT")
# print(input)
# print("ADAPTED INPUT")

# print(payload.model_id)
# print(payload.inputs)
# print(payload.model_dump())

request_id = generate_request_id()

presigned_url = self._get_presigned_url()
def _get_outputs(self) -> str:
    """Fetch the precalculation results for the current request id.

    Posts the stored ``request_id`` to the ``/precalculations`` endpoint
    and returns the raw response body as text.
    """
    endpoint = INFERENCE_STORE_API_URL + "/precalculations"
    query = {"requestid": self.request_id}
    response = requests.post(endpoint, params=query, timeout=60)
    return response.text

response = self._post_inputs(presigned_url)
if response.status_code == 200:
def get_precalculations(self, inputs):
if store_has_model(self.model_id):
self._generate_request_id()
presigned_url = self._get_presigned_url()

response = self._post_inputs(inputs, presigned_url)
if response.status_code == 204:
print('File uploaded successfully')
else:
print('Failed to upload file:', response.status_code, response.text)

precalculations = self._get_outputs(self.model_id, request_id)
return f'Failed to upload file: {response.status_code} error ({response.text})'

return precalculations
precalculations = self._get_outputs()

if __name__ == "__main__":
abc = 'https://bzr6zxw1k0.execute-api.ap-southeast-2.amazonaws.com/4-8-24'
response = requests.get(abc, params={"model_id": 'ey1829ey', "request_id": '12e1-2e-12e12e'})
print(response.json())
# TODO: should missing keys be returned too in a file/message?
return precalculations # this is the result returned to the CLI
26 changes: 19 additions & 7 deletions ersilia/store/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ersilia.default import INFERENCE_STORE_API_URL
import requests
class OutputSource():
LOCAL_ONLY = "local-only"
CLOUD_ONLY = "cloud-only"
# CLOUD_FIRST = "cloud-first"
ALL = [
LOCAL_ONLY,
CLOUD_ONLY,
#CLOUD_FIRST
]

@classmethod
Expand All @@ -14,12 +14,24 @@ def is_local(cls, option):

@classmethod
def is_cloud(cls, option):
return option in (
cls.CLOUD_ONLY,
#cls.CLOUD_FIRST
)
return option == cls.CLOUD_ONLY

from pydantic import BaseModel
class InferenceStoreApiPayload(BaseModel):
    """Payload sent to the inference store: a model id plus its input strings.

    ``inputs`` defaults to an empty list rather than ``None`` because a
    pydantic validation error is raised if inputs is None, e.g. if an
    inchi->smiles conversion fails upstream. (Pydantic deep-copies field
    defaults per instance, so the mutable ``[]`` default is safe here.)
    """
    model: str
    inputs: list[str] = []  # validation error if inputs is None e.g. if inchi->smiles fails


def store_has_model(model_id: str) -> bool:
    """Return True when the inference store's ``/model`` endpoint knows *model_id*.

    A 200 response means the model exists in the store; anything else is
    treated as "not found".
    """
    lookup = requests.get(
        INFERENCE_STORE_API_URL + "/model",
        params={"modelid": model_id},
        timeout=60,
    )
    found = lookup.status_code == 200
    if found:
        print("Model found in inference store: ", lookup.json())
    else:
        print("Model not found in inference store")
    return found

0 comments on commit 87a586c

Please sign in to comment.