-
-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0c1820c
commit 87a586c
Showing
6 changed files
with
93 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,80 @@ | ||
from ersilia.core.base import ErsiliaBase | ||
from ersilia.io.input import GenericInputAdapter | ||
from ersilia.store.utils import InferenceStoreApiPayload | ||
import requests, uuid | ||
from ersilia.store.utils import InferenceStoreApiPayload, store_has_model | ||
from ersilia.default import INFERENCE_STORE_API_URL | ||
import requests | ||
import tempfile | ||
import uuid | ||
|
||
INFERENCE_STORE_UPLOAD_API_URL = "" | ||
|
||
class InferenceStoreApi(ErsiliaBase): | ||
|
||
def __init__(self, model_id): | ||
ErsiliaBase.__init__(self) | ||
self.model_id = model_id | ||
self._upload_url = None | ||
self.request_id = None | ||
self.input_adapter = GenericInputAdapter(model_id=self.model_id) | ||
|
||
### Call Lambda function to get upload URL ### | ||
def _get_presigned_url(self) -> str: | ||
presigned_url = requests.get( | ||
INFERENCE_STORE_API_URL, params={"model_id": self.model_id} | ||
) | ||
|
||
# try: | ||
# upload_url = requests.get( | ||
# INFERENCE_STORE_API_URL, params={"model_id": self.model_id} | ||
# ) | ||
# except: | ||
# upload_url = None | ||
# return upload_url | ||
def _generate_request_id(self) -> str: | ||
self.request_id = str(uuid.uuid4()) | ||
return self.request_id | ||
|
||
def has_model(self) -> bool: | ||
# GET request to check if model exists | ||
|
||
# if self._upload_url is None: # TODO: maybe another condition here to check if URL has expired | ||
# self._upload_url = self._get_upload_url() | ||
# return self._upload_url is not Non | ||
def _get_presigned_url(self) -> str: | ||
response = requests.get( | ||
INFERENCE_STORE_API_URL + "/upload-destination", | ||
params={ | ||
"modelid": self.model_id, | ||
"requestid": self.request_id | ||
}, | ||
timeout=60 | ||
) | ||
presigned_url_response = response.json() | ||
return presigned_url_response | ||
|
||
### TODO: Send payload to S3 ### | ||
def _post_inputs(self, presigned_url) -> str: | ||
adapted_input_generator = self.input_adapter.adapt_one_by_one(input) | ||
def _post_inputs(self, inputs, presigned_url_response) -> str: | ||
adapted_input_generator = self.input_adapter.adapt_one_by_one(inputs) | ||
smiles_list = [val["input"] for val in adapted_input_generator] | ||
payload = InferenceStoreApiPayload(model=self.model_id, inputs=smiles_list) | ||
return "" | ||
|
||
### TODO: Get precalculations from S3 ### | ||
def _get_outputs(self, model_id=None, request_id=None) -> dict: | ||
return {} | ||
with tempfile.NamedTemporaryFile() as input_file: | ||
input_file.write(payload.model_dump()) | ||
inputs = input_file.name | ||
|
||
def get_precalculations(self, input: str): | ||
with open(inputs, 'rb') as f: | ||
files = {'file': (inputs, f)} | ||
presigned_url = presigned_url_response.get('url') | ||
presigned_url_data = presigned_url_response.get('fields') | ||
http_response = requests.post( | ||
presigned_url, | ||
data=presigned_url_data, | ||
files=files, | ||
timeout=60 | ||
) | ||
|
||
def generate_request_id(): | ||
return str(uuid.uuid4()) | ||
return http_response | ||
|
||
# print("-----------------------") | ||
# print("GOT MODEL ID") | ||
# print(model_id) | ||
# print("GOT INPUT") | ||
# print(input) | ||
# print("ADAPTED INPUT") | ||
|
||
# print(payload.model_id) | ||
# print(payload.inputs) | ||
# print(payload.model_dump()) | ||
|
||
request_id = generate_request_id() | ||
|
||
presigned_url = self._get_presigned_url() | ||
def _get_outputs(self) -> str: | ||
response = requests.post( | ||
INFERENCE_STORE_API_URL + "/precalculations", | ||
params={ | ||
"requestid": self.request_id | ||
}, | ||
timeout=60 | ||
) | ||
return response.text | ||
|
||
response = self._post_inputs(presigned_url) | ||
if response.status_code == 200: | ||
def get_precalculations(self, inputs): | ||
if store_has_model(self.model_id): | ||
self._generate_request_id() | ||
presigned_url = self._get_presigned_url() | ||
|
||
response = self._post_inputs(inputs, presigned_url) | ||
if response.status_code == 204: | ||
print('File uploaded successfully') | ||
else: | ||
print('Failed to upload file:', response.status_code, response.text) | ||
|
||
precalculations = self._get_outputs(self.model_id, request_id) | ||
return f'Failed to upload file: {response.status_code} error ({response.text})' | ||
|
||
return precalculations | ||
precalculations = self._get_outputs() | ||
|
||
if __name__ == "__main__": | ||
abc = 'https://bzr6zxw1k0.execute-api.ap-southeast-2.amazonaws.com/4-8-24' | ||
response = requests.get(abc, params={"model_id": 'ey1829ey', "request_id": '12e1-2e-12e12e'}) | ||
print(response.json()) | ||
# TODO: should missing keys be returned too in a file/message? | ||
return precalculations # this is the result returned to the CLI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters