diff --git a/requirements.txt b/requirements.txt index ef8c286..6cf6c1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ amplitude_analytics~=1.1.1 dataclasses-json~=0.6.7 +sseclient-py~=1.8.0 diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index e734127..6532450 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -2,13 +2,17 @@ from typing import Optional import threading +from ..flag.flag_config_updater import FlagConfigPoller, FlagConfigStreamer, FlagConfigUpdaterFallbackRetryWrapper from ..local.config import LocalEvaluationConfig from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import CohortStorage -from ..flag.flag_config_api import FlagConfigApi +from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller -from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags +from ..util.flag_config import get_all_cohort_ids_from_flags + +DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 15000 +DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000 class DeploymentRunner: @@ -16,6 +20,7 @@ def __init__( self, config: LocalEvaluationConfig, flag_config_api: FlagConfigApi, + flag_config_stream_api: Optional[FlagConfigStreamApi], flag_config_storage: FlagConfigStorage, cohort_storage: CohortStorage, logger: logging.Logger, @@ -27,7 +32,22 @@ def __init__( self.cohort_storage = cohort_storage self.cohort_loader = cohort_loader self.lock = threading.Lock() - self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) + self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper( + FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, config, logger), + 
None, + 0, 0, config.flag_config_polling_interval_millis, 0, + logger + ) + if flag_config_stream_api: + self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper( + FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger), + self.flag_updater, + DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS, DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS, + config.flag_config_polling_interval_millis, 0, + logger + ) + + self.cohort_poller = None if self.cohort_loader: self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000, self.__update_cohorts) @@ -35,63 +55,14 @@ def __init__( def start(self): with self.lock: - self.__update_flag_configs() - self.flag_poller.start() + self.flag_updater.start(None) if self.cohort_loader: self.cohort_poller.start() def stop(self): - self.flag_poller.stop() - - def __periodic_flag_update(self): - try: - self.__update_flag_configs() - except Exception as e: - self.logger.warning(f"Error while updating flags: {e}") - - def __update_flag_configs(self): - try: - flag_configs = self.flag_config_api.get_flag_configs() - except Exception as e: - self.logger.warning(f'Failed to fetch flag configs: {e}') - raise e - - flag_keys = {flag.key for flag in flag_configs} - self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) - - if not self.cohort_loader: - for flag_config in flag_configs: - self.logger.debug(f"Putting non-cohort flag {flag_config.key}") - self.flag_config_storage.put_flag_config(flag_config) - return - - new_cohort_ids = set() - for flag_config in flag_configs: - new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) - - existing_cohort_ids = self.cohort_storage.get_cohort_ids() - cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - - # download all new cohorts - try: - self.cohort_loader.download_cohorts(cohort_ids_to_download).result() - except Exception as e: - self.logger.warning(f"Error while downloading cohorts: 
{e}") - - # get updated set of cohort ids - updated_cohort_ids = self.cohort_storage.get_cohort_ids() - # iterate through new flag configs and check if their required cohorts exist - for flag_config in flag_configs: - cohort_ids = get_all_cohort_ids_from_flag(flag_config) - self.logger.debug(f"Storing flag {flag_config.key}") - self.flag_config_storage.put_flag_config(flag_config) - missing_cohorts = cohort_ids - updated_cohort_ids - if missing_cohorts: - self.logger.warning(f"Flag {flag_config.key} - failed to load cohorts: {missing_cohorts}") - - # delete unused cohorts - self._delete_unused_cohorts() - self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + self.flag_updater.stop() + if self.cohort_poller: + self.cohort_poller.stop() def __update_cohorts(self): cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values())) @@ -99,16 +70,3 @@ def __update_cohorts(self): self.cohort_loader.download_cohorts(cohort_ids).result() except Exception as e: self.logger.warning(f"Error while updating cohorts: {e}") - - def _delete_unused_cohorts(self): - flag_cohort_ids = set() - for flag in self.flag_config_storage.get_flag_configs().values(): - flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) - - storage_cohorts = self.cohort_storage.get_cohorts() - deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids - - for deleted_cohort_id in deleted_cohort_ids: - deleted_cohort = storage_cohorts.get(deleted_cohort_id) - if deleted_cohort is not None: - self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) diff --git a/src/amplitude_experiment/flag/__init__.py b/src/amplitude_experiment/flag/__init__.py new file mode 100644 index 0000000..88e7715 --- /dev/null +++ b/src/amplitude_experiment/flag/__init__.py @@ -0,0 +1,2 @@ +from .flag_config_api import FlagConfigStreamApi +from .flag_config_updater import FlagConfigStreamer \ No newline at end of file diff --git 
a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index b27ea15..662f3cf 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -1,11 +1,14 @@ import json -from typing import List +import threading +from http.client import HTTPResponse, HTTPConnection, HTTPSConnection +from typing import List, Optional, Callable, Mapping, Union, Tuple -from ..evaluation.types import EvaluationFlag -from ..version import __version__ +import sseclient from ..connection_pool import HTTPConnectionPool - +from ..util.updater import get_duration_with_jitter +from ..evaluation.types import EvaluationFlag +from ..version import __version__ class FlagConfigApi: def get_flag_configs(self) -> List[EvaluationFlag]: @@ -46,3 +49,178 @@ def __setup_connection_pool(self): timeout = self.flag_config_poller_request_timeout_millis / 1000 self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, read_timeout=timeout, scheme=scheme) + + +DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000 +DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000 +DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000 + + +class EventSource: + def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int, + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS, + keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS): + self.keep_alive_timer: Optional[threading.Timer] = None + self.server_url = server_url + self.path = path + self.headers = headers + self.conn_timeout_millis = conn_timeout_millis + self.max_conn_duration_millis = max_conn_duration_millis + self.max_jitter_millis = max_jitter_millis + self.keep_alive_timeout_millis = keep_alive_timeout_millis + + self.sse: Optional[sseclient.SSEClient] = None + self.conn: Optional[HTTPConnection | 
HTTPSConnection] = None + self.thread: Optional[threading.Thread] = None + self._stopped = False + self.lock = threading.RLock() + + def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]): + with self.lock: + if self.sse is not None: + self.sse.close() + if self.conn is not None: + self.conn.close() + + self.conn, response = self._get_conn() + if response.status != 200: + on_error(f"[Experiment] Stream flagConfigs - received error response: ${response.status}: ${response.read().decode('utf-8')}") + return + + self.sse = sseclient.SSEClient(response, char_enc='utf-8') + self._stopped = False + self.thread = threading.Thread(target=self._run, args=[on_update, on_error]) + self.thread.start() + self.reset_keep_alive_timer(on_error) + + def stop(self): + with self.lock: + self._stopped = True + if self.sse: + self.sse.close() + if self.conn: + self.conn.close() + if self.keep_alive_timer: + self.keep_alive_timer.cancel() + self.sse = None + self.conn = None + # No way to stop self.thread, on self.conn.close(), + # the loop in thread will raise exception, which will terminate the thread. + + def reset_keep_alive_timer(self, on_error: Callable[[str], None]): + with self.lock: + if self.keep_alive_timer: + self.keep_alive_timer.cancel() + self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000, self.keep_alive_timed_out, + args=[on_error]) + self.keep_alive_timer.start() + + def keep_alive_timed_out(self, on_error: Callable[[str], None]): + with self.lock: + if not self._stopped: + self.stop() + on_error("[Experiment] Stream flagConfigs - Keep alive timed out") + + def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]): + try: + for event in self.sse.events(): + with self.lock: + if self._stopped: + return + self.reset_keep_alive_timer(on_error) + if event.data == ' ': + continue + on_update(event.data) + except TimeoutError: + # Due to connection max time reached, open another one. 
+ with self.lock: + if self._stopped: + return + self.stop() + self.start(on_update, on_error) + except Exception as e: + # Closing connection can result in exception here as a way to stop generator. + with self.lock: + if self._stopped: + return + on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e)) + + def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]: + scheme, _, host = self.server_url.split('/', 3) + connection = HTTPConnection if scheme == 'http:' else HTTPSConnection + + body = None + + conn = connection(host, timeout=get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000) + try: + conn.request('GET', self.path, body, self.headers) + response = conn.getresponse() + except Exception as e: + conn.close() + raise e + + return conn, response + + +class FlagConfigStreamApi: + def __init__(self, + deployment_key: str, + server_url: str, + conn_timeout_millis: int, + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS): + self.deployment_key = deployment_key + self.server_url = server_url + self.conn_timeout_millis = conn_timeout_millis + self.max_conn_duration_millis = max_conn_duration_millis + self.max_jitter_millis = max_jitter_millis + + self.lock = threading.RLock() + + headers = { + 'Authorization': f"Api-Key {self.deployment_key}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" + } + + self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis) + + def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]): + with self.lock: + init_finished_event = threading.Event() + init_error_event = threading.Event() + init_updated_event = threading.Event() + + def _on_update(data): + response_json = json.loads(data) + flags = 
EvaluationFlag.schema().load(response_json, many=True) + if init_finished_event.is_set(): + on_update(flags) + else: + init_finished_event.set() + on_update(flags) + init_updated_event.set() + + def _on_error(data): + if init_finished_event.is_set(): + on_error(data) + else: + init_error_event.set() + init_finished_event.set() + on_error(data) + + t = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error]) + t.start() + init_finished_event.wait(self.conn_timeout_millis / 1000) + if t.is_alive() or not init_finished_event.is_set() or init_error_event.is_set(): + self.stop() + on_error("stream connection timeout error") + return + + # Wait for first update callback to finish before returning. + init_updated_event.wait() + + def stop(self): + with self.lock: + threading.Thread(target=lambda: self.eventsource.stop()).start() diff --git a/src/amplitude_experiment/flag/flag_config_updater.py b/src/amplitude_experiment/flag/flag_config_updater.py new file mode 100644 index 0000000..232c75b --- /dev/null +++ b/src/amplitude_experiment/flag/flag_config_updater.py @@ -0,0 +1,268 @@ +import logging +import threading +import time +from typing import List, Callable, Optional + +from ..evaluation.types import EvaluationFlag +from ..local.config import LocalEvaluationConfig +from ..cohort.cohort_storage import CohortStorage +from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi +from ..flag.flag_config_storage import FlagConfigStorage +from ..local.poller import Poller +from ..cohort.cohort_loader import CohortLoader +from ..util.flag_config import get_all_cohort_ids_from_flag +from ..util.updater import get_duration_with_jitter + + +class FlagConfigUpdater: + def start(self, on_error: Optional[Callable[[str], None]]): + pass + + def stop(self): + pass + + +class FlagConfigUpdaterBase: + def __init__(self, + flag_config_storage: FlagConfigStorage, + cohort_loader: CohortLoader, + cohort_storage: CohortStorage, + logger: logging.Logger): + 
self.flag_config_storage = flag_config_storage + self.cohort_loader = cohort_loader + self.cohort_storage = cohort_storage + self.logger = logger + + def update(self, flag_configs: List[EvaluationFlag]): + flag_keys = {flag.key for flag in flag_configs} + self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) + + if not self.cohort_loader: + for flag_config in flag_configs: + self.logger.debug(f"Putting non-cohort flag {flag_config.key}") + self.flag_config_storage.put_flag_config(flag_config) + return + + new_cohort_ids = set() + for flag_config in flag_configs: + new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) + + existing_cohort_ids = self.cohort_storage.get_cohort_ids() + cohort_ids_to_download = new_cohort_ids - existing_cohort_ids + + # download all new cohorts + try: + self.cohort_loader.download_cohorts(cohort_ids_to_download).result() + except Exception as e: + self.logger.warning(f"Error while downloading cohorts: {e}") + + updated_cohort_ids = self.cohort_storage.get_cohort_ids() + # iterate through new flag configs and check if their required cohorts exist + for flag_config in flag_configs: + cohort_ids = get_all_cohort_ids_from_flag(flag_config) + self.logger.debug(f"Storing flag {flag_config.key}") + self.flag_config_storage.put_flag_config(flag_config) + missing_cohorts = cohort_ids - updated_cohort_ids + if missing_cohorts: + self.logger.warning(f"Flag {flag_config.key} - failed to load cohorts: {missing_cohorts}") + + # delete unused cohorts + self._delete_unused_cohorts() + self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + + def _delete_unused_cohorts(self): + flag_cohort_ids = set() + for flag in self.flag_config_storage.get_flag_configs().values(): + flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) + + storage_cohorts = self.cohort_storage.get_cohorts() + deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids + + for deleted_cohort_id in deleted_cohort_ids: + deleted_cohort = 
storage_cohorts.get(deleted_cohort_id) + if deleted_cohort is not None: + self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) + + +class FlagConfigPoller(FlagConfigUpdaterBase, FlagConfigUpdater): + def __init__(self, flag_config_api: FlagConfigApi, flag_config_storage: FlagConfigStorage, + cohort_loader: CohortLoader, + cohort_storage: CohortStorage, config: LocalEvaluationConfig, + logger: logging.Logger): + super().__init__(flag_config_storage, cohort_loader, cohort_storage, logger) + + self.flag_config_api = flag_config_api + self.flag_poller = Poller(config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) + self.logger = logger + + self.on_error = None + + def start(self, on_error: Optional[Callable[[str], None]]): + self.stop() + try: + self.__update_flag_configs() + except Exception as e: + self.logger.warning(f"Error while updating flags: {e}") + raise e + self.on_error = on_error + self.flag_poller.start() + + def stop(self): + self.flag_poller.stop() + + def __periodic_flag_update(self): + try: + self.__update_flag_configs() + except Exception as e: + self.logger.warning(f"Error while updating flags: {e}") + self.stop() + if self.on_error: + self.on_error(e) + + def __update_flag_configs(self): + try: + flag_configs = self.flag_config_api.get_flag_configs() + except Exception as e: + self.logger.warning(f'Failed to fetch flag configs: {e}') + raise e + + super().update(flag_configs) + + +class FlagConfigStreamer(FlagConfigUpdaterBase, FlagConfigUpdater): + def __init__(self, flag_config_stream_api: FlagConfigStreamApi, flag_config_storage: FlagConfigStorage, + cohort_loader: CohortLoader, + cohort_storage: CohortStorage, + logger: logging.Logger): + super().__init__(flag_config_storage, cohort_loader, cohort_storage, logger) + + self.flag_config_stream_api = flag_config_stream_api + self.logger = logger + + def start(self, on_error: Optional[Callable[[str], None]]): + def _on_error(err): + 
self.flag_config_stream_api.stop() + if on_error: + on_error(err) + + self.flag_config_stream_api.start(super().update, _on_error) + + def stop(self): + self.flag_config_stream_api.stop() + + +class FlagConfigUpdaterFallbackRetryWrapper(FlagConfigUpdater): + def __init__(self, main_updater: FlagConfigUpdater, fallback_updater: Optional[FlagConfigUpdater], + retry_delay_millis: int, max_jitter_millis: int, + fallback_start_retry_delay_millis: int, fallback_start_retry_max_jitter_millis: int, + logger: logging.Logger): + super().__init__() + + self.main_updater = main_updater + self.fallback_updater = fallback_updater + self.retry_delay_millis = retry_delay_millis + self.max_jitter_millis = max_jitter_millis + self.fallback_start_retry_delay_millis = fallback_start_retry_delay_millis + self.fallback_start_retry_max_jitter_millis = fallback_start_retry_max_jitter_millis + self.logger = logger + + self.main_retry_stopper = threading.Event() + self.fallback_retry_stopper = threading.Event() + + self.lock = threading.RLock() + + def start(self, on_error: Optional[Callable[[str], None]]): + with self.lock: + def _fallback_on_error(err: str): + pass + + def _main_on_error(err: str): + self.start_main_retry(_main_on_error) + try: + if self.fallback_updater is not None: + self.fallback_updater.start(_fallback_on_error) + except: + self.start_fallback_retry(_fallback_on_error) + + try: + self.main_updater.start(_main_on_error) + if self.fallback_updater is not None: + self.fallback_updater.stop() + self.stop_main_retry() + self.stop_fallback_retry() + except Exception as e: + if self.fallback_updater is not None: + self.fallback_updater.start(_fallback_on_error) + self.start_main_retry(_main_on_error) + else: + raise e + + def stop(self): + with self.lock: + self.main_retry_stopper.set() + self.fallback_retry_stopper.set() + self.main_updater.stop() + if self.fallback_updater is not None: + self.fallback_updater.stop() + + def start_main_retry(self, main_on_error: 
Callable[[str], None]): + with self.lock: + # Schedule main retry indefinitely. Only stop on some signal. + if self.main_retry_stopper: + self.main_retry_stopper.set() + + stopper = threading.Event() + + def retry_main(): + while True: + time.sleep(get_duration_with_jitter(self.retry_delay_millis, self.max_jitter_millis) / 1000) + with self.lock: + if stopper.is_set(): + break + try: + self.main_updater.start(main_on_error) + stopper.set() + self.stop_fallback_retry() + if self.fallback_updater is not None: + self.fallback_updater.stop() + break + except: + pass + + threading.Thread(target=retry_main).start() + self.main_retry_stopper = stopper + + def start_fallback_retry(self, fallback_on_error: Callable[[str], None]): + with self.lock: + # Schedule fallback retry indefinitely. Only stop on some signal. + if self.fallback_retry_stopper: + self.fallback_retry_stopper.set() + + if self.fallback_updater is None: + return + + stopper = threading.Event() + + def retry_fallback(): + while True: + time.sleep(get_duration_with_jitter(self.fallback_start_retry_delay_millis, + self.fallback_start_retry_max_jitter_millis) / 1000) + with self.lock: + if stopper.is_set(): + break + try: + if self.fallback_updater is not None: + self.fallback_updater.start(fallback_on_error) + stopper.set() + break + except: + pass + + threading.Thread(target=retry_fallback).start() + self.fallback_retry_stopper = stopper + + def stop_main_retry(self): + self.main_retry_stopper.set() + + def stop_fallback_retry(self): + self.fallback_retry_stopper.set() diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 12d7cc3..ac72b39 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -11,7 +11,7 @@ from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import InMemoryCohortStorage from ..deployment.deployment_runner import DeploymentRunner -from ..flag.flag_config_api import 
FlagConfigApiV2 +from ..flag.flag_config_api import FlagConfigApiV2, FlagConfigStreamApi from ..flag.flag_config_storage import InMemoryFlagConfigStorage from ..user import User from ..connection_pool import HTTPConnectionPool @@ -66,8 +66,13 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) - self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, - self.cohort_storage, self.logger, cohort_loader) + flag_config_stream_api = None + if self.config.stream_updates: + flag_config_stream_api = FlagConfigStreamApi(api_key, self.config.stream_server_url, self.config.stream_flag_conn_timeout) + + self.deployment_runner = DeploymentRunner(self.config, flag_config_api, flag_config_stream_api, + self.flag_config_storage, self.cohort_storage, self.logger, + cohort_loader) def start(self): """ diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index c729e36..2331f97 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -6,6 +6,9 @@ DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' EU_SERVER_URL = 'https://flag.lab.eu.amplitude.com' +DEFAULT_STREAM_URL = 'https://stream.lab.amplitude.com' +EU_STREAM_SERVER_URL = 'https://stream.lab.eu.amplitude.com' + class ServerZone(Enum): US = "US" @@ -20,6 +23,9 @@ def __init__(self, debug: bool = False, server_zone: ServerZone = ServerZone.US, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, + stream_updates: bool = False, + stream_server_url: str = DEFAULT_STREAM_URL, + stream_flag_conn_timeout: int = 1500, assignment_config: AssignmentConfig = None, cohort_sync_config: CohortSyncConfig = None): """ @@ -48,6 +54,12 @@ def 
__init__(self, debug: bool = False, cohort_sync_config.cohort_server_url == DEFAULT_COHORT_SYNC_URL): self.cohort_sync_config.cohort_server_url = EU_COHORT_SYNC_URL + self.stream_server_url = stream_server_url + if stream_server_url == DEFAULT_SERVER_URL and server_zone == ServerZone.EU: + self.stream_server_url = EU_STREAM_SERVER_URL + self.flag_config_polling_interval_millis = flag_config_polling_interval_millis self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis + self.stream_updates = stream_updates + self.stream_flag_conn_timeout = stream_flag_conn_timeout self.assignment_config = assignment_config diff --git a/src/amplitude_experiment/util/updater.py b/src/amplitude_experiment/util/updater.py new file mode 100644 index 0000000..765c6e5 --- /dev/null +++ b/src/amplitude_experiment/util/updater.py @@ -0,0 +1,5 @@ +import random + + +def get_duration_with_jitter(duration: int, jitter: int): + return max(0, duration + (random.randrange(-jitter, jitter) if jitter != 0 else 0)) \ No newline at end of file diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index f64a332..337a49a 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -3,6 +3,8 @@ from unittest.mock import patch import logging +from amplitude_experiment.evaluation.types import EvaluationFlag + from src.amplitude_experiment import LocalEvaluationConfig from src.amplitude_experiment.cohort.cohort_loader import CohortLoader from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig @@ -44,6 +46,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): runner = DeploymentRunner( LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), flag_api, + None, flag_config_storage, cohort_storage, logger, @@ -63,15 +66,15 @@ def test_start_does_not_throw_if_cohort_load_fails(self): cohort_loader = 
CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), - flag_api, flag_config_storage, + flag_api, None, flag_config_storage, cohort_storage, logger, cohort_loader, ) # Mock methods as needed - with patch.object(runner, '_delete_unused_cohorts'): - flag_api.get_flag_configs.return_value = [self.flag] + with patch.object(runner.flag_updater.main_updater, '_delete_unused_cohorts'): + flag_api.get_flag_configs.return_value = EvaluationFlag.schema().load([self.flag], many=True) cohort_download_api.get_cohort.side_effect = RuntimeError("test") # Simply call the method and let the test pass if no exception is raised diff --git a/tests/flag/flag_config_api_test.py b/tests/flag/flag_config_api_test.py new file mode 100644 index 0000000..e6cef97 --- /dev/null +++ b/tests/flag/flag_config_api_test.py @@ -0,0 +1,90 @@ +import json +import threading +import time +import unittest +from unittest.mock import MagicMock, patch + +from amplitude_experiment.flag.flag_config_api import FlagConfigStreamApi + + +def response(code: int, body: dict = None): + mock_response = MagicMock() + mock_response.status = code + if body is not None: + mock_response.read.return_value = json.dumps(body).encode() + return mock_response + + +BARE_FLAG = [{ + "key": "flag", + "variants": {}, + "segments": [ + { + "conditions": [ + [ + { + "selector": ["context", "user", "cohort_ids"], + "op": "set contains any", + "values": ["COHORT_ID"], + } + ] + ], + } + ] +}] + + +class FlagConfigStreamApiTest(unittest.TestCase): + def setUp(self) -> None: + self.api = FlagConfigStreamApi("deployment_key", "server_url", 2000, 5000, 0) + self.success_count = 0 + self.error_count = 0 + def on_success(self, data): + self.success_count += (1 if data[0].key == "flag" else 0) + def on_error(self, data): + self.error_count += 1 + + def test_connect_and_get_data_success(self): + with patch.object(self.api.eventsource, 
'start') as es: + assert self.success_count == 0 + threading.Thread(target=self.api.start, args=[self.on_success, self.on_error]).start() + time.sleep(1) + es.call_args[0][0](json.dumps(BARE_FLAG)) + assert self.success_count == 1 + assert self.error_count == 0 + + def test_connect_timeout(self): + with patch.object(self.api.eventsource, 'start') as es: + assert self.success_count == 0 + assert self.error_count == 0 + threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() + time.sleep(3) + assert self.success_count == 0 + assert self.error_count == 1 + + def test_connect_error(self): + with patch.object(self.api.eventsource, 'start') as es: + assert self.success_count == 0 + assert self.error_count == 0 + threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() + time.sleep(1) + es.call_args[0][1]('error') + assert self.success_count == 0 + assert self.error_count == 1 + + def test_connect_success_but_error_later(self): + with patch.object(self.api.eventsource, 'start') as es: + assert self.success_count == 0 + assert self.error_count == 0 + threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() + time.sleep(1) + es.call_args[0][0](json.dumps(BARE_FLAG)) + assert self.success_count == 1 + assert self.error_count == 0 + es.call_args[0][1]('error') + assert self.success_count == 1 + assert self.error_count == 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/flag/flag_config_updater_test.py b/tests/flag/flag_config_updater_test.py new file mode 100644 index 0000000..b9329d6 --- /dev/null +++ b/tests/flag/flag_config_updater_test.py @@ -0,0 +1,226 @@ +import threading +import time +import unittest +from typing import Optional, Callable +from unittest import mock +from unittest.mock import patch + +from amplitude_experiment.cohort.cohort_loader import CohortLoader +from amplitude_experiment.flag import FlagConfigStreamApi +from 
amplitude_experiment.flag.flag_config_updater import FlagConfigUpdater, FlagConfigUpdaterFallbackRetryWrapper, \ + FlagConfigStreamer + + +class FlagConfigStreamerTest(unittest.TestCase): + def setUp(self) -> None: + self.stream_api = mock.create_autospec(FlagConfigStreamApi) + cohort_loader = mock.create_autospec(CohortLoader) + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() + logger = mock.Mock() + self.streamer = FlagConfigStreamer(self.stream_api, flag_config_storage, cohort_loader, cohort_storage, logger) + + self.err_count = 0 + + def on_error(self, msg): + self.err_count += 1 + + def test_start_then_stopped(self): + with patch.object(self.stream_api, "start") as start_func: + with patch.object(self.stream_api, "stop") as stop_func: + self.streamer.start(self.on_error) + assert start_func.call_count == 1 + assert stop_func.call_count == 0 + assert self.err_count == 0 + self.streamer.stop() + assert start_func.call_count == 1 + assert stop_func.call_count == 1 + assert self.err_count == 0 + + def test_start_then_error_stopped(self): + with patch.object(self.stream_api, "start") as start_func: + with patch.object(self.stream_api, "stop") as stop_func: + self.streamer.start(self.on_error) + assert start_func.call_count == 1 + assert stop_func.call_count == 0 + assert self.err_count == 0 + start_func.call_args[0][1]("error") + assert start_func.call_count == 1 + assert stop_func.call_count == 1 + assert self.err_count == 1 + + +class DummyUpdater(FlagConfigUpdater): + def __init__(self, name): + self.name = name + self.fail_time = -1 + self.stopped_event = threading.Event() + self.start_count = 0 + self.stop_count = 0 + + def start(self, on_error: Optional[Callable[[str], None]]): + self.start_count += 1 + print(self.name + " start") + self.stopped_event.set() + stopped_event = threading.Event() + self.stopped_event = stopped_event + + if self.fail_time == 0: + print(self.name + " start fail") 
+ raise Exception() + if self.fail_time > 0: + def fail(): + time.sleep(self.fail_time) + if not stopped_event.is_set(): + print(self.name + " failed") + if on_error: + on_error("failed") + + threading.Thread(target=fail).start() + + def stop(self): + self.stop_count += 1 + print(self.name + " stopped") + self.stopped_event.set() + + +class FlagConfigUpdaterFallbackRetryWrapperTest(unittest.TestCase): + def setUp(self) -> None: + self.dummy1 = DummyUpdater("dummy1") + self.dummy2 = DummyUpdater("dummy2") + self.logger = mock.Mock() + self.api = FlagConfigUpdaterFallbackRetryWrapper(self.dummy1, self.dummy2, 1000, 0, 500, 0, self.logger) + + def test_main_start_success(self): + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + + def test_main_start_failed_fallback_start_success_main_start_success_on_retry(self): + self.dummy1.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = -1 + time.sleep(1.1) + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + # No more restarts + time.sleep(2) + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + + def test_main_start_failed_fallback_start_failed(self): + self.dummy1.fail_time = 0 + self.dummy2.fail_time = 0 + try: + self.api.start(None) + raise Exception() + except: + pass + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + time.sleep(2) + # No retry + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + + def test_main_start_success_later_failed_fallback_start_success_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = -1 + time.sleep(1) + # Now main 
retried success + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + + def test_main_start_success_later_failed_fallback_start_failed_later_fallback_retry_success_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = 0 + time.sleep(0.5) + # Fallback retried + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 2 + time.sleep(0.5) + # Main and Fallback retried + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 3 + self.dummy1.fail_time = -1 + self.dummy2.fail_time = -1 + time.sleep(0.5) + # Fallback retried success + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 4 + time.sleep(0.5) + # Main retried success + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count == 4 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count == 4 + + def test_main_start_success_later_failed_fallback_start_failed_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = 0 + time.sleep(0.5) + # Fallback retried + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 2 + time.sleep(0.5) + # Main and Fallback retried + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 3 + self.dummy1.fail_time = -1 + time.sleep(0.5) + # Fallback retried, still failing + assert 
self.dummy1.start_count == 2 + assert self.dummy2.start_count == 4 + time.sleep(0.5) + # Main retried success, fallback may or may not retry, but no more after main success + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count <= 5 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count <= 5 + + def test_main_start_success_later_failed_fallback_start_failed_later_stopped(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + self.dummy1.fail_time = 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.api.stop() diff --git a/tests/local/client_test.py b/tests/local/client_test.py index b6c50eb..9b4b49b 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -15,6 +15,7 @@ class LocalEvaluationClientTestCase(unittest.TestCase): _local_evaluation_client: LocalEvaluationClient = None + _stream_update: bool = False @classmethod def setUpClass(cls) -> None: @@ -25,7 +26,8 @@ def setUpClass(cls) -> None: secret_key=secret_key) cls._local_evaluation_client = ( LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, - cohort_sync_config=cohort_sync_config))) + cohort_sync_config=cohort_sync_config, + stream_updates=cls._stream_update))) cls._local_evaluation_client.start() @classmethod @@ -110,5 +112,9 @@ def test_evaluation_cohorts_not_in_storage_with_sync_config(self): self.assertTrue(any(re.match(log_message, message) for message in log.output)) +class LocalEvaluationClientStreamingTestCase(LocalEvaluationClientTestCase): + _stream_update: bool = True + + if __name__ == '__main__': unittest.main()