diff --git a/capella_console_client/client.py b/capella_console_client/client.py index c24764d..de667b5 100644 --- a/capella_console_client/client.py +++ b/capella_console_client/client.py @@ -1,14 +1,11 @@ import logging import sys -from datetime import datetime from typing import List, Dict, Any, Union, Optional, no_type_check, Tuple from collections import defaultdict from pathlib import Path import tempfile -import dateutil.parser - from capella_console_client.config import CONSOLE_API_URL from capella_console_client.session import CapellaConsoleSession from capella_console_client.logconf import logger @@ -33,12 +30,14 @@ from capella_console_client.repeat_request import create_repeat_request from capella_console_client.validate import ( _validate_uuid, + _validate_uuids, _validate_stac_id_or_stac_items, _validate_and_filter_product_types, _validate_and_filter_asset_types, _validate_and_filter_stac_ids, ) from capella_console_client.sort import _sort_stac_items +from capella_console_client.order import get_order, get_non_expired_orders class CapellaConsoleClient: @@ -266,39 +265,34 @@ def list_orders(self, *order_ids: Optional[str], is_active: Optional[bool] = Fal list orders Args: - order_id: list only specific orders (variadic, specify multiple) + order_id: list only specific orders (variadic, specify multiple) - if omitted all orders are listed is_active: list only active (non-expired) orders Returns: List[Dict[str, Any]]: metadata of orders """ - orders = [] if order_ids: - for order_id in order_ids: - _validate_uuid(order_id) + _validate_uuids(order_ids) # prefilter non expired if is_active: - orders = _get_non_expired_orders(session=self._sesh) + orders = get_non_expired_orders(session=self._sesh) if order_ids: - set_order_ids = set(order_ids) - orders = [o for o in orders if o["orderId"] in set_order_ids] - else: - # list all orders - if not order_ids: - params = { - "customerId": self._sesh.customer_id, - } - resp = self._sesh.get("/orders", params=params) - orders = resp.json() - - # list specific orders - else: - for order_id in order_ids: - resp = self._sesh.get(f"/orders/{order_id}") - orders.append(resp.json()) + orders = [o for o in orders if o["orderId"] in set(order_ids)] + return orders + # list specific orders + if order_ids: + orders = [self._sesh.get(f"/orders/{order_id}").json() for order_id in order_ids] + return orders + + # list all orders of customer + params = { + "customerId": self._sesh.customer_id, + } + resp = self._sesh.get("/orders", params=params) + orders = resp.json() return orders def get_stac_items_of_order(self, order_id: str, ids_only: bool = False) -> Union[List[str], SearchResult]: @@ -324,7 +318,6 @@ def review_order( contract_id: Optional[str] = None, ) -> Dict[str, Any]: stac_ids = _validate_stac_id_or_stac_items(stac_ids, items) - logger.info(f"reviewing order for {', '.join(stac_ids)}") stac_items = items # type: ignore @@ -373,22 +366,29 @@ def submit_order( stac_ids = _validate_stac_id_or_stac_items(stac_ids, items) if check_active_orders: - order_id = self._find_active_order(stac_ids) + order_id = get_order(session=self._sesh, stac_ids=stac_ids) if order_id is not None: - logger.info(f"found active order {order_id}") + logger.info(f"found existing order {order_id} containing all requested stac ids") return order_id - if stac_ids and not omit_search: - stac_items = self.search(ids=stac_ids) - else: - if omit_search and not items: - logger.warning("setting omit_search=True only works in combination providing items instead of stac_ids") + def _get_stac_items(): + if stac_ids and not omit_search: stac_items = self.search(ids=stac_ids) else: - stac_items = items # type: ignore + if omit_search and not items: + logger.warning( + "setting omit_search=True only works in combination providing items instead of stac_ids" + ) + stac_items = self.search(ids=stac_ids) + else: + stac_items = items # type: ignore - if not stac_items: - raise NoValidStacIdsError(f"No valid STAC IDs in {', '.join(stac_ids)}") + if not stac_items: + raise NoValidStacIdsError(f"No valid STAC IDs in {', '.join(stac_ids)}") + + return stac_items + + stac_items = _get_stac_items() if not omit_review: self.review_order(items=stac_items, contract_id=contract_id) @@ -420,25 +420,6 @@ def _construct_order_payload(self, stac_items, contract_id: Optional[str] = None return payload - def _find_active_order(self, stac_ids: List[str]) -> Optional[str]: - """ - find active order containing ALL specified `stac_ids` - - Args: - stac_ids: STAC IDs that active order should include - """ - order_id = None - active_orders = _get_non_expired_orders(session=self._sesh) - if not active_orders: - return None - - for ord in active_orders: - granules = set([i["granuleId"] for i in ord["items"]]) - if granules.issuperset(stac_ids): - order_id = ord["orderId"] - break - return order_id - def get_presigned_items( self, order_id: str, @@ -883,23 +864,3 @@ def search(self, **kwargs) -> SearchResult: """ search = StacSearch(session=self._sesh, **kwargs) return search.fetch_all() - - -def _get_non_expired_orders(session: CapellaConsoleSession) -> List[Dict[str, Any]]: - params = {"customerId": session.customer_id} - res = session.get("/orders", params=params) - - all_orders = res.json() - - ordered_by_exp_date = sorted(all_orders, key=lambda x: x["expirationDate"]) - now = datetime.utcnow() - - active_orders = [] - while ordered_by_exp_date: - cur = ordered_by_exp_date.pop() - cur_exp_date = dateutil.parser.parse(cur["expirationDate"], ignoretz=True) - if cur_exp_date < now: - break - active_orders.append(cur) - - return active_orders diff --git a/capella_console_client/order.py b/capella_console_client/order.py new file mode 100644 index 0000000..47f91a1 --- /dev/null +++ b/capella_console_client/order.py @@ -0,0 +1,45 @@ +from datetime import datetime, timezone +import dateutil.parser +from typing import List, Optional, Dict, Any + +from capella_console_client.session import CapellaConsoleSession + + +def get_order(session: CapellaConsoleSession, stac_ids: List[str]) -> Optional[str]: + """ + find active order containing ALL specified `stac_ids` + + Args: + stac_ids: STAC IDs that active order should include + """ + order_id = None + active_orders = get_non_expired_orders(session=session) + if not active_orders: + return None + + for ord in active_orders: + granules = set([i["granuleId"] for i in ord["items"]]) + if granules.issuperset(stac_ids): + order_id = ord["orderId"] + break + return order_id + + +def get_non_expired_orders(session: CapellaConsoleSession) -> List[Dict[str, Any]]: + params = {"customerId": session.customer_id} + res = session.get("/orders", params=params) + + all_orders = res.json() + + ordered_by_exp_date = sorted(all_orders, key=lambda x: x["expirationDate"]) + now = datetime.now(tz=timezone.utc) + + active_orders = [] + while ordered_by_exp_date: + cur = ordered_by_exp_date.pop() + cur_exp_date = dateutil.parser.parse(cur["expirationDate"], ignoretz=False) + if cur_exp_date < now: + break + active_orders.append(cur) + + return active_orders diff --git a/capella_console_client/validate.py b/capella_console_client/validate.py index 952e108..11824bd 100644 --- a/capella_console_client/validate.py +++ b/capella_console_client/validate.py @@ -21,6 +21,12 @@ def _validate_uuid(uuid_str: str) -> None: raise ValueError(f"{uuid_str} is not a valid uuid: {e}") +@no_type_check +def _validate_uuids(uuid_strs: List[str]): + for uuid_str in uuid_strs: + _validate_uuid(uuid_str) + + def _validate_stac_id_or_stac_items( stac_ids: Optional[List[str]] = None, items: Union[Optional[List[Dict[str, Any]]], SearchResult] = None, diff --git a/tests/conftest.py b/tests/conftest.py index a7b9426..8a3e1fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,7 @@ def assert_all_responses_were_requested() -> bool: @pytest.fixture def disable_validate_uuid(monkeypatch): monkeypatch.setattr(capella_client_module, "_validate_uuid", lambda x: None) + monkeypatch.setattr(capella_client_module, "_validate_uuids", lambda x: None) monkeypatch.setattr(tasking_modules, "_validate_uuid", lambda x: None) diff --git a/tests/test_order.py b/tests/test_order.py index dabab1e..16a7c05 100644 --- a/tests/test_order.py +++ b/tests/test_order.py @@ -38,7 +38,7 @@ def test_list_no_active_orders(order_client): def test_list_active_orders_with_order_ids(test_client, monkeypatch, disable_validate_uuid): monkeypatch.setattr( capella_client_module, - "_get_non_expired_orders", + "get_non_expired_orders", lambda session: get_mock_responses("/orders"), )