Skip to content

Commit

Permalink
stripping out some order concerns out of client (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
grmpflh27 authored Jan 14, 2025
1 parent 83aa6c5 commit 4118c6e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 75 deletions.
109 changes: 35 additions & 74 deletions capella_console_client/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import logging
import sys

from datetime import datetime
from typing import List, Dict, Any, Union, Optional, no_type_check, Tuple
from collections import defaultdict
from pathlib import Path
import tempfile

import dateutil.parser

from capella_console_client.config import CONSOLE_API_URL
from capella_console_client.session import CapellaConsoleSession
from capella_console_client.logconf import logger
Expand All @@ -33,12 +30,14 @@
from capella_console_client.repeat_request import create_repeat_request
from capella_console_client.validate import (
_validate_uuid,
_validate_uuids,
_validate_stac_id_or_stac_items,
_validate_and_filter_product_types,
_validate_and_filter_asset_types,
_validate_and_filter_stac_ids,
)
from capella_console_client.sort import _sort_stac_items
from capella_console_client.order import get_order, get_non_expired_orders


class CapellaConsoleClient:
Expand Down Expand Up @@ -266,39 +265,34 @@ def list_orders(self, *order_ids: Optional[str], is_active: Optional[bool] = Fal
list orders
Args:
order_id: list only specific orders (variadic, specify multiple)
order_id: list only specific orders (variadic, specify multiple) - if omitted all orders are listed
is_active: list only active (non-expired) orders
Returns:
List[Dict[str, Any]]: metadata of orders
"""
orders = []

if order_ids:
for order_id in order_ids:
_validate_uuid(order_id)
_validate_uuids(order_ids)

# prefilter non expired
if is_active:
orders = _get_non_expired_orders(session=self._sesh)
orders = get_non_expired_orders(session=self._sesh)
if order_ids:
set_order_ids = set(order_ids)
orders = [o for o in orders if o["orderId"] in set_order_ids]
else:
# list all orders
if not order_ids:
params = {
"customerId": self._sesh.customer_id,
}
resp = self._sesh.get("/orders", params=params)
orders = resp.json()

# list specific orders
else:
for order_id in order_ids:
resp = self._sesh.get(f"/orders/{order_id}")
orders.append(resp.json())
orders = [o for o in orders if o["orderId"] in set(order_ids)]
return orders

# list specific orders
if order_ids:
orders = [self._sesh.get(f"/orders/{order_id}").json() for order_id in order_ids]
return orders

# list all orders of customer
params = {
"customerId": self._sesh.customer_id,
}
resp = self._sesh.get("/orders", params=params)
orders = resp.json()
return orders

def get_stac_items_of_order(self, order_id: str, ids_only: bool = False) -> Union[List[str], SearchResult]:
Expand All @@ -324,7 +318,6 @@ def review_order(
contract_id: Optional[str] = None,
) -> Dict[str, Any]:
stac_ids = _validate_stac_id_or_stac_items(stac_ids, items)

logger.info(f"reviewing order for {', '.join(stac_ids)}")

stac_items = items # type: ignore
Expand Down Expand Up @@ -373,22 +366,29 @@ def submit_order(
stac_ids = _validate_stac_id_or_stac_items(stac_ids, items)

if check_active_orders:
order_id = self._find_active_order(stac_ids)
order_id = get_order(session=self._sesh, stac_ids=stac_ids)
if order_id is not None:
logger.info(f"found active order {order_id}")
logger.info(f"found existing order {order_id} containing all requested stac ids")
return order_id

if stac_ids and not omit_search:
stac_items = self.search(ids=stac_ids)
else:
if omit_search and not items:
logger.warning("setting omit_search=True only works in combination providing items instead of stac_ids")
def _get_stac_items():
if stac_ids and not omit_search:
stac_items = self.search(ids=stac_ids)
else:
stac_items = items # type: ignore
if omit_search and not items:
logger.warning(
"setting omit_search=True only works in combination providing items instead of stac_ids"
)
stac_items = self.search(ids=stac_ids)
else:
stac_items = items # type: ignore

if not stac_items:
raise NoValidStacIdsError(f"No valid STAC IDs in {', '.join(stac_ids)}")
if not stac_items:
raise NoValidStacIdsError(f"No valid STAC IDs in {', '.join(stac_ids)}")

return stac_items

stac_items = _get_stac_items()

if not omit_review:
self.review_order(items=stac_items, contract_id=contract_id)
Expand Down Expand Up @@ -420,25 +420,6 @@ def _construct_order_payload(self, stac_items, contract_id: Optional[str] = None

return payload

def _find_active_order(self, stac_ids: List[str]) -> Optional[str]:
"""
find active order containing ALL specified `stac_ids`
Args:
stac_ids: STAC IDs that active order should include
"""
order_id = None
active_orders = _get_non_expired_orders(session=self._sesh)
if not active_orders:
return None

for ord in active_orders:
granules = set([i["granuleId"] for i in ord["items"]])
if granules.issuperset(stac_ids):
order_id = ord["orderId"]
break
return order_id

def get_presigned_items(
self,
order_id: str,
Expand Down Expand Up @@ -883,23 +864,3 @@ def search(self, **kwargs) -> SearchResult:
"""
search = StacSearch(session=self._sesh, **kwargs)
return search.fetch_all()


def _get_non_expired_orders(session: CapellaConsoleSession) -> List[Dict[str, Any]]:
params = {"customerId": session.customer_id}
res = session.get("/orders", params=params)

all_orders = res.json()

ordered_by_exp_date = sorted(all_orders, key=lambda x: x["expirationDate"])
now = datetime.utcnow()

active_orders = []
while ordered_by_exp_date:
cur = ordered_by_exp_date.pop()
cur_exp_date = dateutil.parser.parse(cur["expirationDate"], ignoretz=True)
if cur_exp_date < now:
break
active_orders.append(cur)

return active_orders
45 changes: 45 additions & 0 deletions capella_console_client/order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from datetime import datetime, timezone
import dateutil.parser
from typing import List, Optional, Dict, Any

from capella_console_client.session import CapellaConsoleSession


def get_order(session: CapellaConsoleSession, stac_ids: List[str]) -> Optional[str]:
"""
find active order containing ALL specified `stac_ids`
Args:
stac_ids: STAC IDs that active order should include
"""
order_id = None
active_orders = get_non_expired_orders(session=session)
if not active_orders:
return None

for ord in active_orders:
granules = set([i["granuleId"] for i in ord["items"]])
if granules.issuperset(stac_ids):
order_id = ord["orderId"]
break
return order_id


def get_non_expired_orders(session: CapellaConsoleSession) -> List[Dict[str, Any]]:
params = {"customerId": session.customer_id}
res = session.get("/orders", params=params)

all_orders = res.json()

ordered_by_exp_date = sorted(all_orders, key=lambda x: x["expirationDate"])
now = datetime.now(tz=timezone.utc)

active_orders = []
while ordered_by_exp_date:
cur = ordered_by_exp_date.pop()
cur_exp_date = dateutil.parser.parse(cur["expirationDate"], ignoretz=False)
if cur_exp_date < now:
break
active_orders.append(cur)

return active_orders
6 changes: 6 additions & 0 deletions capella_console_client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def _validate_uuid(uuid_str: str) -> None:
raise ValueError(f"{uuid_str} is not a valid uuid: {e}")


@no_type_check
def _validate_uuids(uuid_strs: List[str]):
for uuid_str in uuid_strs:
_validate_uuid(uuid_str)


def _validate_stac_id_or_stac_items(
stac_ids: Optional[List[str]] = None,
items: Union[Optional[List[Dict[str, Any]]], SearchResult] = None,
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def assert_all_responses_were_requested() -> bool:
@pytest.fixture
def disable_validate_uuid(monkeypatch):
monkeypatch.setattr(capella_client_module, "_validate_uuid", lambda x: None)
monkeypatch.setattr(capella_client_module, "_validate_uuids", lambda x: None)
monkeypatch.setattr(tasking_modules, "_validate_uuid", lambda x: None)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_list_no_active_orders(order_client):
def test_list_active_orders_with_order_ids(test_client, monkeypatch, disable_validate_uuid):
monkeypatch.setattr(
capella_client_module,
"_get_non_expired_orders",
"get_non_expired_orders",
lambda session: get_mock_responses("/orders"),
)

Expand Down

0 comments on commit 4118c6e

Please sign in to comment.