From f017510e0bd0d11e42802c2b7cfb0bd676c23071 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 21 Dec 2024 02:44:21 -0800 Subject: [PATCH 1/5] Add stream feat --- requirements.txt | 3 +- .../deployment/deployment_runner.py | 91 ++---- src/amplitude_experiment/flag/__init__.py | 2 + .../flag/flag_config_api.py | 186 ++++++++++++- .../flag/flag_config_updater.py | 260 ++++++++++++++++++ src/amplitude_experiment/flag/main.py | 28 ++ src/amplitude_experiment/local/client.py | 11 +- src/amplitude_experiment/local/config.py | 12 + src/amplitude_experiment/util/updater.py | 5 + tests/flag/flag_config_api_test.py | 72 +++++ tests/flag/flag_config_updater_test.py | 222 +++++++++++++++ tests/local/client_test.py | 8 +- 12 files changed, 823 insertions(+), 77 deletions(-) create mode 100644 src/amplitude_experiment/flag/__init__.py create mode 100644 src/amplitude_experiment/flag/flag_config_updater.py create mode 100644 src/amplitude_experiment/flag/main.py create mode 100644 src/amplitude_experiment/util/updater.py create mode 100644 tests/flag/flag_config_api_test.py create mode 100644 tests/flag/flag_config_updater_test.py diff --git a/requirements.txt b/requirements.txt index 26cc8cb..54c9407 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -amplitude_analytics~=1.1.1 \ No newline at end of file +amplitude_analytics~=1.1.1 +sseclient-py~=1.8.0 \ No newline at end of file diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index aa8aa64..e291398 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -2,20 +2,25 @@ from typing import Optional import threading +from ..flag.flag_config_updater import FlagConfigPoller, FlagConfigStreamer, FlagConfigUpdaterFallbackRetryWrapper from ..local.config import LocalEvaluationConfig from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage 
import CohortStorage -from ..flag.flag_config_api import FlagConfigApi +from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags +streamUpdaterRetryDelayMillis = 15000 +updaterRetryMaxJitterMillis = 1000 + class DeploymentRunner: def __init__( self, config: LocalEvaluationConfig, flag_config_api: FlagConfigApi, + flag_config_stream_api: Optional[FlagConfigStreamApi], flag_config_storage: FlagConfigStorage, cohort_storage: CohortStorage, logger: logging.Logger, @@ -27,7 +32,18 @@ def __init__( self.cohort_storage = cohort_storage self.cohort_loader = cohort_loader self.lock = threading.Lock() - self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) + self.flag_updater = FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, + config, logger) + if flag_config_stream_api: + self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper( + FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger), + self.flag_updater, + streamUpdaterRetryDelayMillis, updaterRetryMaxJitterMillis, + config.flag_config_polling_interval_millis, 0, + logger + ) + + self.cohort_poller = None if self.cohort_loader: self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000, self.__update_cohorts) @@ -35,63 +51,15 @@ def __init__( def start(self): with self.lock: - self.__update_flag_configs() - self.flag_poller.start() + self.flag_updater.start(None) + print("flag updater start finished") if self.cohort_loader: self.cohort_poller.start() def stop(self): - self.flag_poller.stop() - - def __periodic_flag_update(self): - try: - self.__update_flag_configs() - except Exception as e: - self.logger.warning(f"Error while updating flags: {e}") - - 
def __update_flag_configs(self): - try: - flag_configs = self.flag_config_api.get_flag_configs() - except Exception as e: - self.logger.warning(f'Failed to fetch flag configs: {e}') - raise e - - flag_keys = {flag['key'] for flag in flag_configs} - self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) - - if not self.cohort_loader: - for flag_config in flag_configs: - self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") - self.flag_config_storage.put_flag_config(flag_config) - return - - new_cohort_ids = set() - for flag_config in flag_configs: - new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) - - existing_cohort_ids = self.cohort_storage.get_cohort_ids() - cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - - # download all new cohorts - try: - self.cohort_loader.download_cohorts(cohort_ids_to_download).result() - except Exception as e: - self.logger.warning(f"Error while downloading cohorts: {e}") - - # get updated set of cohort ids - updated_cohort_ids = self.cohort_storage.get_cohort_ids() - # iterate through new flag configs and check if their required cohorts exist - for flag_config in flag_configs: - cohort_ids = get_all_cohort_ids_from_flag(flag_config) - self.logger.debug(f"Storing flag {flag_config['key']}") - self.flag_config_storage.put_flag_config(flag_config) - missing_cohorts = cohort_ids - updated_cohort_ids - if missing_cohorts: - self.logger.warning(f"Flag {flag_config['key']} - failed to load cohorts: {missing_cohorts}") - - # delete unused cohorts - self._delete_unused_cohorts() - self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + self.flag_updater.stop() + if self.cohort_poller: + self.cohort_poller.stop() def __update_cohorts(self): cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values())) @@ -99,16 +67,3 @@ def __update_cohorts(self): self.cohort_loader.download_cohorts(cohort_ids).result() except Exception as e: 
self.logger.warning(f"Error while updating cohorts: {e}") - - def _delete_unused_cohorts(self): - flag_cohort_ids = set() - for flag in self.flag_config_storage.get_flag_configs().values(): - flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) - - storage_cohorts = self.cohort_storage.get_cohorts() - deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids - - for deleted_cohort_id in deleted_cohort_ids: - deleted_cohort = storage_cohorts.get(deleted_cohort_id) - if deleted_cohort is not None: - self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) diff --git a/src/amplitude_experiment/flag/__init__.py b/src/amplitude_experiment/flag/__init__.py new file mode 100644 index 0000000..88e7715 --- /dev/null +++ b/src/amplitude_experiment/flag/__init__.py @@ -0,0 +1,2 @@ +from .flag_config_api import FlagConfigStreamApi +from .flag_config_updater import FlagConfigStreamer \ No newline at end of file diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index 15db645..f06db0a 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -1,10 +1,15 @@ import json -from typing import List +import random +import threading +import time +from http.client import HTTPResponse, HTTPConnection, HTTPSConnection +from typing import List, Tuple, Optional, Callable, Mapping -from ..version import __version__ - -from ..connection_pool import HTTPConnectionPool +import sseclient +from ..connection_pool import HTTPConnectionPool, WrapperHTTPConnection +from ..util.updater import get_duration_with_jitter +from ..version import __version__ class FlagConfigApi: def get_flag_configs(self) -> List: @@ -45,3 +50,176 @@ def __setup_connection_pool(self): timeout = self.flag_config_poller_request_timeout_millis / 1000 self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, read_timeout=timeout, scheme=scheme) + + 
# Defaults for the flag-config stream connection. All three values are milliseconds.
streamApiKeepaliveTimeout = 17000
streamApiReconnIntervalMillis = 15 * 60 * 1000
streamApiMaxJitterMillis = 5000


class EventSource:
    """Server-sent-events connection that hands event payloads to callbacks.

    ``start`` opens an HTTP(S) GET request against ``server_url`` + ``path``,
    wraps the response in an ``sseclient.SSEClient`` and reads events on a
    background thread. A keep-alive timer is re-armed on every received event;
    when it expires the connection is closed and ``on_error`` is called.
    """

    def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int,
                 max_conn_duration_millis: int = streamApiReconnIntervalMillis,
                 max_jitter_millis: int = streamApiMaxJitterMillis,
                 keep_alive_timeout_millis: int = streamApiKeepaliveTimeout):
        """
        Parameters:
            server_url: Base URL of the stream server, e.g. ``https://host``.
            path: Request path of the event stream.
            headers: HTTP headers sent with the request.
            conn_timeout_millis: Stored on the instance; not read by this class.
            max_conn_duration_millis: Base socket timeout of a connection, in ms.
            max_jitter_millis: Jitter applied to ``max_conn_duration_millis``, in ms.
            keep_alive_timeout_millis: Maximum silence between events, in ms.
        """
        self.keep_alive_timer: Optional[threading.Timer] = None
        self.server_url = server_url
        self.path = path
        self.headers = headers
        self.conn_timeout_millis = conn_timeout_millis
        self.max_conn_duration_millis = max_conn_duration_millis
        self.max_jitter_millis = max_jitter_millis
        self.keep_alive_timeout_millis = keep_alive_timeout_millis

        self.sse: Optional[sseclient.SSEClient] = None
        self.conn: Optional[HTTPConnection] = None
        self.thread: Optional[threading.Thread] = None
        self._stopped = False
        # Re-entrant: the reader thread calls stop()/start() while holding it.
        self.lock = threading.RLock()

    def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
        """Open the stream and start the reader thread.

        Any previous connection is closed first. A non-200 response is
        reported through ``on_error`` and no reader thread is started.
        Exceptions raised while connecting propagate to the caller.
        """
        with self.lock:
            if self.sse is not None:
                self.sse.close()
                self.sse = None
            if self.conn is not None:
                self.conn.close()
                self.conn = None

            self.conn, response = self._get_conn()
            if response.status != 200:
                body = response.read().decode('utf-8')
                # Nothing will read from this connection; release it now.
                self.conn.close()
                self.conn = None
                on_error(f"[Experiment] Stream flagConfigs - received error response: {response.status}: {body}")
                return

            self.sse = sseclient.SSEClient(response, char_enc='utf-8')
            self._stopped = False
            self.thread = threading.Thread(target=self._run, args=[on_update, on_error])
            self.thread.start()
            self.reset_keep_alive_timer(on_error)

    def stop(self):
        """Close the stream and cancel the keep-alive timer."""
        with self.lock:
            self._stopped = True
            if self.sse:
                self.sse.close()
            if self.conn:
                self.conn.close()
            if self.keep_alive_timer:
                self.keep_alive_timer.cancel()
            self.sse = None
            self.conn = None
            # The reader thread cannot be stopped directly; closing the
            # connection makes its event loop raise, and _run returns because
            # _stopped is set.

    def reset_keep_alive_timer(self, on_error: Callable[[str], None]):
        """(Re)arm the keep-alive timer."""
        with self.lock:
            if self.keep_alive_timer:
                self.keep_alive_timer.cancel()
            # threading.Timer takes seconds; the configured value is milliseconds.
            self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000,
                                                    self.keep_alive_timed_out, args=[on_error])
            self.keep_alive_timer.start()

    def keep_alive_timed_out(self, on_error: Callable[[str], None]):
        """Timer callback: close a still-open stream and report the timeout."""
        with self.lock:
            if self.conn and self.sse:
                self.stop()
                on_error("[Experiment] Stream flagConfigs - Keep alive timed out")

    def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
        """Reader-thread body: dispatch events until stopped or the stream fails."""
        import socket  # local import: before Python 3.10 socket.timeout is not TimeoutError

        try:
            for event in self.sse.events():
                with self.lock:
                    if self._stopped:
                        return
                    self.reset_keep_alive_timer(on_error)
                    if event.data == ' ':
                        # A single-space payload is skipped after re-arming the timer.
                        continue
                    on_update(event.data)
        except (TimeoutError, socket.timeout):
            # The socket timeout doubles as the maximum connection duration:
            # open a fresh connection.
            with self.lock:
                if self._stopped:
                    return
                self.stop()
                try:
                    self.start(on_update, on_error)
                except Exception as e:
                    # Without this the reader thread would die silently.
                    on_error(f"[Experiment] Stream flagConfigs - reconnect failed: {e!r}")
        except Exception as e:
            # Closing the connection raises here; that is how stop() ends the loop.
            with self.lock:
                if self._stopped:
                    return
                on_error(f"[Experiment] Stream flagConfigs - stream failed: {e!r}")

    def _get_conn(self) -> Tuple[HTTPConnection, HTTPResponse]:
        """Send the GET request and return ``(connection, response)``.

        Raises:
            ValueError: if ``server_url`` has no host part.
            Exception: any error raised while requesting; the connection is
                closed before re-raising.
        """
        from urllib.parse import urlsplit  # local import: not a file-level import of this module

        url = urlsplit(self.server_url)
        if not url.netloc:
            raise ValueError(f"Invalid stream server url: {self.server_url!r}")
        # Plain HTTP only when explicitly requested; everything else uses TLS.
        connection = HTTPConnection if url.scheme == 'http' else HTTPSConnection

        timeout_seconds = get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000
        conn = connection(url.netloc, timeout=timeout_seconds)
        try:
            conn.request('GET', self.path, None, self.headers)
            response = conn.getresponse()
        except Exception:
            conn.close()
            raise

        return conn, response
class FlagConfigUpdater:
    """Interface of a component that keeps flag configs up to date."""

    def start(self, on_error: Optional[Callable[[str], None]]):
        """Start updating; implementations call ``on_error`` when updating fails."""
        pass

    def stop(self):
        """Stop updating."""
        pass


class FlagConfigUpdaterBase:
    """Shared logic that writes a fetched list of flag configs into storage."""

    def __init__(self,
                 flag_config_storage: FlagConfigStorage,
                 cohort_loader: CohortLoader,
                 cohort_storage: CohortStorage,
                 logger: logging.Logger):
        self.flag_config_storage = flag_config_storage
        self.cohort_loader = cohort_loader
        self.cohort_storage = cohort_storage
        self.logger = logger

    def update(self, flag_configs: List):
        """Replace the stored flag configs with ``flag_configs``.

        Flags whose key is absent from ``flag_configs`` are removed. When a
        cohort loader is configured, cohorts referenced by the new flags that
        are not yet stored are downloaded first, and cohorts no longer
        referenced by any stored flag are deleted afterwards. A failed cohort
        download is logged and does not prevent the flags from being stored.
        """
        flag_keys = {flag['key'] for flag in flag_configs}
        self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys)

        if not self.cohort_loader:
            for flag_config in flag_configs:
                self.logger.debug(f"Putting non-cohort flag {flag_config['key']}")
                self.flag_config_storage.put_flag_config(flag_config)
            return

        new_cohort_ids = set()
        for flag_config in flag_configs:
            new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config))

        existing_cohort_ids = self.cohort_storage.get_cohort_ids()
        cohort_ids_to_download = new_cohort_ids - existing_cohort_ids

        # Download all new cohorts; failures are reported per flag below.
        try:
            self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
        except Exception as e:
            self.logger.warning(f"Error while downloading cohorts: {e}")

        # Re-read the stored ids so flags with missing cohorts can be reported.
        updated_cohort_ids = self.cohort_storage.get_cohort_ids()
        for flag_config in flag_configs:
            cohort_ids = get_all_cohort_ids_from_flag(flag_config)
            self.logger.debug(f"Storing flag {flag_config['key']}")
            self.flag_config_storage.put_flag_config(flag_config)
            missing_cohorts = cohort_ids - updated_cohort_ids
            if missing_cohorts:
                self.logger.warning(f"Flag {flag_config['key']} - failed to load cohorts: {missing_cohorts}")

        self._delete_unused_cohorts()
        self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")

    def _delete_unused_cohorts(self):
        """Delete stored cohorts that no stored flag config references."""
        flag_cohort_ids = set()
        for flag in self.flag_config_storage.get_flag_configs().values():
            flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag))

        storage_cohorts = self.cohort_storage.get_cohorts()
        deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids

        for deleted_cohort_id in deleted_cohort_ids:
            deleted_cohort = storage_cohorts.get(deleted_cohort_id)
            if deleted_cohort is not None:
                self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id)


class FlagConfigPoller(FlagConfigUpdaterBase):
    """Updater that fetches flag configs from ``flag_config_api`` on a fixed interval."""

    def __init__(self, flag_config_api: FlagConfigApi, flag_config_storage: FlagConfigStorage,
                 cohort_loader: CohortLoader,
                 cohort_storage: CohortStorage, config: LocalEvaluationConfig,
                 logger: logging.Logger):
        super().__init__(flag_config_storage, cohort_loader, cohort_storage, logger)

        self.flag_config_api = flag_config_api
        self.flag_poller = Poller(config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)

        self.on_error = None

    def start(self, on_error: Optional[Callable[[str], None]]):
        """Fetch once, then start periodic polling.

        A failure of the initial fetch is logged and polling starts anyway.
        NOTE(review): the code this replaced raised on an initial fetch
        failure - confirm that swallowing it here is intended.
        """
        self.stop()
        try:
            self.__update_flag_configs()
        except Exception as e:
            self.logger.warning(f"Error while updating flags: {e}")
        self.on_error = on_error
        self.flag_poller.start()

    def stop(self):
        """Stop periodic polling."""
        self.flag_poller.stop()

    def __periodic_flag_update(self):
        """Poller callback: refresh the flags and handle a failed refresh."""
        try:
            self.__update_flag_configs()
        except Exception as e:
            self.logger.warning(f"Error while updating flags: {e}")
            # Only give up polling when a handler exists to take over.
            # Without one, a single failed fetch would end updates for good.
            if self.on_error:
                self.stop()
                self.on_error(str(e))

    def __update_flag_configs(self):
        """Fetch the flag configs and store them; fetch errors are logged and re-raised."""
        try:
            flag_configs = self.flag_config_api.get_flag_configs()
        except Exception as e:
            self.logger.warning(f'Failed to fetch flag configs: {e}')
            raise

        self.update(flag_configs)


class FlagConfigStreamer(FlagConfigUpdaterBase):
    """Updater that stores flag configs pushed by ``flag_config_stream_api``."""

    def __init__(self, flag_config_stream_api: FlagConfigStreamApi, flag_config_storage: FlagConfigStorage,
                 cohort_loader: CohortLoader,
                 cohort_storage: CohortStorage,
                 logger: logging.Logger):
        super().__init__(flag_config_storage, cohort_loader, cohort_storage, logger)

        self.flag_config_stream_api = flag_config_stream_api

    def start(self, on_error: Optional[Callable[[str], None]]):
        """Start the stream; on a stream error the stream is stopped and ``on_error`` is called."""
        def _on_error(err):
            self.flag_config_stream_api.stop()
            if on_error:
                on_error(err)

        self.flag_config_stream_api.start(self.update, _on_error)

    def stop(self):
        """Stop the stream."""
        self.flag_config_stream_api.stop()


class FlagConfigUpdaterFallbackRetryWrapper(FlagConfigUpdater):
    """Runs a main updater and switches to a fallback updater while the main one is down.

    When the main updater fails, the fallback is started and the main updater
    is retried on a background thread until it starts again, at which point
    the fallback is stopped. When the fallback itself fails to start after a
    main-updater error, it is retried on its own background thread.
    """

    def __init__(self, main_updater: FlagConfigUpdater, fallback_updater: FlagConfigUpdater,
                 retry_delay_millis: int, max_jitter_millis: int,
                 fallback_start_retry_delay_millis: int, fallback_start_retry_max_jitter_millis: int,
                 logger: logging.Logger):
        super().__init__()

        self.main_updater = main_updater
        self.fallback_updater = fallback_updater
        self.retry_delay_millis = retry_delay_millis
        self.max_jitter_millis = max_jitter_millis
        self.fallback_start_retry_delay_millis = fallback_start_retry_delay_millis
        self.fallback_start_retry_max_jitter_millis = fallback_start_retry_max_jitter_millis
        self.logger = logger

        # Each retry loop owns one Event; setting it cancels that loop.
        self.main_retry_stopper = threading.Event()
        self.fallback_retry_stopper = threading.Event()

        self.lock = threading.RLock()
        # Count of errors reported by the main updater. Comparing it around a
        # start call reveals failures delivered through the callback rather
        # than raised.
        self._main_error_count = 0

    def start(self, on_error: Optional[Callable[[str], None]]):
        """Start the main updater, falling back when it fails.

        ``on_error`` is accepted for interface compatibility and is not called.

        Raises:
            Exception: whatever the fallback updater raises when the main
                updater raised on start and the fallback cannot start either.
        """
        with self.lock:
            def _fallback_on_error(err: str):
                # Errors of the running fallback are ignored.
                pass

            def _main_on_error(err: str):
                with self.lock:
                    self._main_error_count += 1
                self.start_main_retry(_main_on_error)
                try:
                    self.fallback_updater.start(_fallback_on_error)
                except Exception as e:
                    self.logger.debug(f"Fallback updater failed to start, scheduling retry: {e}")
                    self.start_fallback_retry(_fallback_on_error)

            try:
                if self._start_main(_main_on_error):
                    self.fallback_updater.stop()
                    self.stop_main_retry()
                    self.stop_fallback_retry()
                # Otherwise the main updater reported its failure through
                # _main_on_error, which already started the fallback and the
                # retry. Stopping them here would leave no updater running.
            except Exception:
                if self.fallback_updater is not None:
                    self.fallback_updater.start(_fallback_on_error)
                    self.start_main_retry(_main_on_error)
                else:
                    raise

    def stop(self):
        """Cancel pending retries and stop both updaters."""
        with self.lock:
            self.main_retry_stopper.set()
            self.fallback_retry_stopper.set()
            self.main_updater.stop()
            self.fallback_updater.stop()

    def start_main_retry(self, main_on_error: Callable[[str], None]):
        """Schedule background retries of the main updater, replacing any pending retry."""
        with self.lock:
            self.main_retry_stopper.set()
            stopper = threading.Event()

            def attempt():
                started_cleanly = self._start_main(main_on_error)
                stopper.set()
                if started_cleanly:
                    self.stop_fallback_retry()
                    self.fallback_updater.stop()

            def delay_seconds():
                return get_duration_with_jitter(self.retry_delay_millis, self.max_jitter_millis) / 1000

            threading.Thread(target=self._retry_until_started, args=[stopper, delay_seconds, attempt],
                             daemon=True).start()
            self.main_retry_stopper = stopper

    def start_fallback_retry(self, fallback_on_error: Callable[[str], None]):
        """Schedule background retries of the fallback updater, replacing any pending retry."""
        with self.lock:
            self.fallback_retry_stopper.set()
            stopper = threading.Event()

            def attempt():
                self.fallback_updater.start(fallback_on_error)
                stopper.set()

            def delay_seconds():
                return get_duration_with_jitter(self.fallback_start_retry_delay_millis,
                                                self.fallback_start_retry_max_jitter_millis) / 1000

            threading.Thread(target=self._retry_until_started, args=[stopper, delay_seconds, attempt],
                             daemon=True).start()
            self.fallback_retry_stopper = stopper

    def stop_main_retry(self):
        """Cancel the pending main-updater retry loop."""
        self.main_retry_stopper.set()

    def stop_fallback_retry(self):
        """Cancel the pending fallback-updater retry loop."""
        self.fallback_retry_stopper.set()

    def _start_main(self, main_on_error: Callable[[str], None]) -> bool:
        """Start the main updater; return False when it reported an error while starting."""
        with self.lock:
            errors_before = self._main_error_count
            self.main_updater.start(main_on_error)
            return self._main_error_count == errors_before

    def _retry_until_started(self, stopper: threading.Event, delay_seconds: Callable[[], float],
                             attempt: Callable[[], None]):
        """Thread body: call ``attempt`` after each delay until it succeeds or ``stopper`` is set."""
        while True:
            # Event.wait instead of sleep so that cancelling ends the loop at once.
            if stopper.wait(delay_seconds()):
                return
            with self.lock:
                if stopper.is_set():
                    return
                try:
                    attempt()
                    return
                except Exception as e:
                    self.logger.debug(f"Updater start retry failed, retrying: {e}")
'https://skylab-stream.stag2.amplitude.com', 1500, 1000 * 5, 0) + storage = InMemoryFlagConfigStorage() + # streamer = FlagConfigStreamer(api, storage, None, None, logger) + + dummy1 = DummyUpdater("dummy 1") + # dummy1.start_fail = True + dummy1.fail = True + dummy2 = DummyUpdater("dummy 2") + dummy2.start_fail = True + streamer = FlagConfigUpdaterFallbackRetryWrapper(dummy1, dummy2, logger) + + print("start") + streamer.start(print) + # print(storage.get_flag_configs()) + time.sleep(20) + + streamer.stop() + print("done") \ No newline at end of file diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index ba917d4..be02d57 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -13,7 +13,7 @@ from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import InMemoryCohortStorage from ..deployment.deployment_runner import DeploymentRunner -from ..flag.flag_config_api import FlagConfigApiV2 +from ..flag.flag_config_api import FlagConfigApiV2, FlagConfigStreamApi from ..flag.flag_config_storage import InMemoryFlagConfigStorage from ..user import User from ..connection_pool import HTTPConnectionPool @@ -67,8 +67,13 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) - self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, - self.cohort_storage, self.logger, cohort_loader) + flag_config_stream_api = None + if self.config.stream_updates: + flag_config_stream_api = FlagConfigStreamApi(api_key, self.config.stream_server_url, self.config.stream_flag_conn_timeout) + + self.deployment_runner = DeploymentRunner(self.config, flag_config_api, flag_config_stream_api, + self.flag_config_storage, self.cohort_storage, 
import random


def get_duration_with_jitter(duration: int, jitter: int):
    """Return ``duration`` shifted by a random offset of at most ``jitter``.

    The offset is drawn uniformly from ``[-|jitter|, +|jitter|]`` (both ends
    included) and the result is clamped at zero, so the returned value is
    never negative. A jitter of ``0`` returns ``duration`` unchanged (or ``0``
    when ``duration`` is negative).

    Parameters:
        duration: Base duration, in the caller's unit.
        jitter: Maximum deviation from ``duration``, in the same unit; the
            sign is ignored.

    Returns:
        The jittered duration as a non-negative integer.
    """
    spread = abs(jitter)
    # randint includes both bounds; randrange(-j, j) could never yield +j and
    # raised ValueError for a negative j.
    offset = random.randint(-spread, spread) if spread else 0
    return max(0, duration + offset)
self.success_count == 0 + assert self.error_count == 1 + + def test_connect_success_but_error_later(self): + with patch.object(self.api.eventsource, 'start') as es: + assert self.success_count == 0 + assert self.error_count == 0 + threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() + time.sleep(1) + es.call_args[0][0]('"apple"') + assert self.success_count == 1 + assert self.error_count == 0 + es.call_args[0][1]('error') + assert self.success_count == 1 + assert self.error_count == 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/flag/flag_config_updater_test.py b/tests/flag/flag_config_updater_test.py new file mode 100644 index 0000000..d93b7e6 --- /dev/null +++ b/tests/flag/flag_config_updater_test.py @@ -0,0 +1,222 @@ +import threading +import time +import unittest +from typing import Optional, Callable +from unittest import mock +from unittest.mock import patch + +from amplitude_experiment.cohort.cohort_loader import CohortLoader +from amplitude_experiment.flag import FlagConfigStreamApi +from amplitude_experiment.flag.flag_config_updater import FlagConfigUpdater, FlagConfigUpdaterFallbackRetryWrapper, \ + FlagConfigStreamer + + +class FlagConfigStreamerTest(unittest.TestCase): + def setUp(self) -> None: + self.stream_api = mock.create_autospec(FlagConfigStreamApi) + cohort_loader = mock.create_autospec(CohortLoader) + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() + logger = mock.Mock() + self.streamer = FlagConfigStreamer(self.stream_api, flag_config_storage, cohort_loader, cohort_storage, logger) + + self.err_count = 0 + + def on_error(self, msg): + self.err_count += 1 + + def test_start_then_stopped(self): + with patch.object(self.stream_api, "start") as start_func: + with patch.object(self.stream_api, "stop") as stop_func: + self.streamer.start(self.on_error) + assert start_func.call_count == 1 + assert stop_func.call_count == 0 
+ assert self.err_count == 0 + self.streamer.stop() + assert start_func.call_count == 1 + assert stop_func.call_count == 1 + assert self.err_count == 0 + + def test_start_then_error_stopped(self): + with patch.object(self.stream_api, "start") as start_func: + with patch.object(self.stream_api, "stop") as stop_func: + self.streamer.start(self.on_error) + assert start_func.call_count == 1 + assert stop_func.call_count == 0 + assert self.err_count == 0 + start_func.call_args[0][1]("error") + assert start_func.call_count == 1 + assert stop_func.call_count == 1 + assert self.err_count == 1 + + +class DummyUpdater(FlagConfigUpdater): + def __init__(self, name): + self.name = name + self.fail_time = -1 + self.stopped_event = threading.Event() + self.start_count = 0 + self.stop_count = 0 + + def start(self, on_error: Optional[Callable[[str], None]]): + self.start_count += 1 + print(self.name + " start") + self.stopped_event.set() + stopped_event = threading.Event() + self.stopped_event = stopped_event + + if self.fail_time == 0: + print(self.name + " start fail") + raise Exception() + if self.fail_time > 0: + def fail(): + time.sleep(self.fail_time) + if not stopped_event.is_set(): + print(self.name + " failed") + if on_error: + on_error("failed") + + threading.Thread(target=fail).start() + + def stop(self): + self.stop_count += 1 + print(self.name + " stopped") + self.stopped_event.set() + + +class FlagConfigUpdaterFallbackRetryWrapperTest(unittest.TestCase): + def setUp(self) -> None: + self.dummy1 = DummyUpdater("dummy1") + self.dummy2 = DummyUpdater("dummy2") + self.logger = mock.Mock() + self.api = FlagConfigUpdaterFallbackRetryWrapper(self.dummy1, self.dummy2, 1000, 0, 500, 0, self.logger) + + def test_main_start_success(self): + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + + def test_main_start_failed_fallback_start_success_main_start_success_on_retry(self): + self.dummy1.fail_time = 0 + self.api.start(None) + 
assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = 0 + time.sleep(1.1) + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + + def test_main_start_failed_fallback_start_failed(self): + self.dummy1.fail_time = 0 + self.dummy2.fail_time = 0 + try: + self.api.start(None) + raise Exception() + except: + pass + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + time.sleep(2) + # No retry + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + + def test_main_start_success_later_failed_fallback_start_success_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = -1 + time.sleep(1) + # Now main retried successfully + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 + + def test_main_start_success_later_failed_fallback_start_failed_later_fallback_retry_success_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = 0 + time.sleep(0.5) + # Fallback retried + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 2 + time.sleep(0.5) + # Main and Fallback retried + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 3 + self.dummy1.fail_time = -1 + self.dummy2.fail_time = -1 + time.sleep(0.5) + # Fallback retried success + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 4 +
time.sleep(0.5) + # Main retried success + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count == 4 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count == 4 + + def test_main_start_success_later_failed_fallback_start_failed_later_main_retry_success(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + self.dummy1.fail_time = 0 + time.sleep(0.5) + # Fallback retried + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 2 + time.sleep(0.5) + # Main and Fallback retried + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 3 + self.dummy1.fail_time = -1 + time.sleep(0.5) + # Fallback retried still fail + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 4 + time.sleep(0.5) + # Main retried success, fallback may or may not retry, but no more after main success + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count <= 5 + time.sleep(2) + # No more retry + assert self.dummy1.start_count == 3 + assert self.dummy2.start_count <= 5 + + def test_main_start_success_later_failed_fallback_start_failed_later_stopped(self): + self.dummy1.fail_time = 2 + self.dummy2.fail_time = 0 + self.api.start(None) + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 0 + self.dummy1.fail_time = 0 + time.sleep(2.1) + # Now main failed + assert self.dummy1.start_count == 1 + assert self.dummy2.start_count == 1 + diff --git a/tests/local/client_test.py b/tests/local/client_test.py index b6c50eb..9b4b49b 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -15,6 +15,7 @@ class LocalEvaluationClientTestCase(unittest.TestCase): _local_evaluation_client: LocalEvaluationClient = None + 
_stream_update: bool = False @classmethod def setUpClass(cls) -> None: @@ -25,7 +26,8 @@ def setUpClass(cls) -> None: secret_key=secret_key) cls._local_evaluation_client = ( LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, - cohort_sync_config=cohort_sync_config))) + cohort_sync_config=cohort_sync_config, + stream_updates=cls._stream_update))) cls._local_evaluation_client.start() @classmethod @@ -110,5 +112,9 @@ def test_evaluation_cohorts_not_in_storage_with_sync_config(self): self.assertTrue(any(re.match(log_message, message) for message in log.output)) +class LocalEvaluationClientStreamingTestCase(LocalEvaluationClientTestCase): + _stream_update: bool = True + + if __name__ == '__main__': unittest.main() From 2e0d8af1e0235d34c560c4403f4e91548b4e1571 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 21 Dec 2024 06:38:59 -0800 Subject: [PATCH 2/5] Use EvaluationFlag, fix keep alive ms, rename defaults, fix tests --- .../deployment/deployment_runner.py | 15 ++++++---- .../flag/flag_config_api.py | 30 ++++++++++--------- .../flag/flag_config_updater.py | 25 +++++++++++----- src/amplitude_experiment/flag/main.py | 28 ----------------- tests/deployment/deployment_runner_test.py | 9 ++++-- tests/flag/flag_config_api_test.py | 29 +++++++++++++++--- tests/flag/flag_config_updater_test.py | 8 +++-- 7 files changed, 79 insertions(+), 65 deletions(-) delete mode 100644 src/amplitude_experiment/flag/main.py diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index e291398..000af7f 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -11,8 +11,8 @@ from ..local.poller import Poller from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags -streamUpdaterRetryDelayMillis = 15000 -updaterRetryMaxJitterMillis = 1000 +DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 
15000 +DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000 class DeploymentRunner: @@ -32,13 +32,17 @@ def __init__( self.cohort_storage = cohort_storage self.cohort_loader = cohort_loader self.lock = threading.Lock() - self.flag_updater = FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, - config, logger) + self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper( + FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, config, logger), + None, + 0, 0, config.flag_config_polling_interval_millis, 0, + logger + ) if flag_config_stream_api: self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper( FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger), self.flag_updater, - streamUpdaterRetryDelayMillis, updaterRetryMaxJitterMillis, + DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS, DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS, config.flag_config_polling_interval_millis, 0, logger ) @@ -52,7 +56,6 @@ def __init__( def start(self): with self.lock: self.flag_updater.start(None) - print("flag updater start finished") if self.cohort_loader: self.cohort_poller.start() diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index e5ebaad..9411546 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -53,16 +53,16 @@ def __setup_connection_pool(self): read_timeout=timeout, scheme=scheme) -streamApiKeepaliveTimeout = 17000 -streamApiReconnIntervalMillis = 15 * 60 * 1000 -streamApiMaxJitterMillis = 5000 +DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000 +DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000 +DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000 class EventSource: def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int, - max_conn_duration_millis: int = 
streamApiReconnIntervalMillis, - max_jitter_millis: int = streamApiMaxJitterMillis, - keep_alive_timeout_millis: int = streamApiKeepaliveTimeout): + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS, + keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS): self.keep_alive_timer: Optional[threading.Timer] = None self.server_url = server_url self.path = path @@ -114,12 +114,13 @@ def reset_keep_alive_timer(self, on_error: Callable[[str], None]): with self.lock: if self.keep_alive_timer: self.keep_alive_timer.cancel() - self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis, self.keep_alive_timed_out, args=[on_error]) + self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000, self.keep_alive_timed_out, + args=[on_error]) self.keep_alive_timer.start() def keep_alive_timed_out(self, on_error: Callable[[str], None]): with self.lock: - if self.conn and self.sse: + if not self._stopped: self.stop() on_error("[Experiment] Stream flagConfigs - Keep alive timed out") @@ -145,7 +146,7 @@ def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None] with self.lock: if self._stopped: return - on_error(e) + on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e)) def _get_conn(self) -> tuple[HTTPConnection | HTTPSConnection, HTTPResponse]: scheme, _, host = self.server_url.split('/', 3) @@ -169,8 +170,8 @@ def __init__(self, deployment_key: str, server_url: str, conn_timeout_millis: int, - max_conn_duration_millis: int = streamApiReconnIntervalMillis, - max_jitter_millis: int = streamApiMaxJitterMillis): + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS): self.deployment_key = deployment_key self.server_url = server_url self.conn_timeout_millis = conn_timeout_millis @@ -187,14 +188,15 @@ def __init__(self, 
self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis) - def start(self, on_update: Callable[[List], None], on_error: Callable[[str], None]): + def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]): with self.lock: init_finished_event = threading.Event() init_error_event = threading.Event() init_updated_event = threading.Event() def _on_update(data): - flags = json.loads(data) + response_json = json.loads(data) + flags = EvaluationFlag.schema().load(response_json, many=True) if init_finished_event.is_set(): on_update(flags) else: @@ -210,7 +212,7 @@ def _on_error(data): init_finished_event.set() on_error(data) - t = threading.Thread(target=lambda: self.eventsource.start(_on_update, _on_error)) + t = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error]) t.start() init_finished_event.wait(self.conn_timeout_millis / 1000) if t.is_alive() or not init_finished_event.is_set() or init_error_event.is_set(): diff --git a/src/amplitude_experiment/flag/flag_config_updater.py b/src/amplitude_experiment/flag/flag_config_updater.py index 646037b..1c6b045 100644 --- a/src/amplitude_experiment/flag/flag_config_updater.py +++ b/src/amplitude_experiment/flag/flag_config_updater.py @@ -84,7 +84,7 @@ def _delete_unused_cohorts(self): self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) -class FlagConfigPoller(FlagConfigUpdaterBase): +class FlagConfigPoller(FlagConfigUpdaterBase, FlagConfigUpdater): def __init__(self, flag_config_api: FlagConfigApi, flag_config_storage: FlagConfigStorage, cohort_loader: CohortLoader, cohort_storage: CohortStorage, config: LocalEvaluationConfig, @@ -103,6 +103,7 @@ def start(self, on_error: Optional[Callable[[str], None]]): self.__update_flag_configs() except Exception as e: self.logger.warning(f"Error while updating flags: {e}") + raise e self.on_error = on_error self.flag_poller.start() @@ -128,7 +129,7 @@ def 
__update_flag_configs(self): super().update(flag_configs) -class FlagConfigStreamer(FlagConfigUpdaterBase): +class FlagConfigStreamer(FlagConfigUpdaterBase, FlagConfigUpdater): def __init__(self, flag_config_stream_api: FlagConfigStreamApi, flag_config_storage: FlagConfigStorage, cohort_loader: CohortLoader, cohort_storage: CohortStorage, @@ -151,7 +152,7 @@ def stop(self): class FlagConfigUpdaterFallbackRetryWrapper(FlagConfigUpdater): - def __init__(self, main_updater: FlagConfigUpdater, fallback_updater: FlagConfigUpdater, + def __init__(self, main_updater: FlagConfigUpdater, fallback_updater: Optional[FlagConfigUpdater], retry_delay_millis: int, max_jitter_millis: int, fallback_start_retry_delay_millis: int, fallback_start_retry_max_jitter_millis: int, logger: logging.Logger): @@ -178,13 +179,15 @@ def _fallback_on_error(err: str): def _main_on_error(err: str): self.start_main_retry(_main_on_error) try: - self.fallback_updater.start(_fallback_on_error) + if self.fallback_updater is not None: + self.fallback_updater.start(_fallback_on_error) except: self.start_fallback_retry(_fallback_on_error) try: self.main_updater.start(_main_on_error) - self.fallback_updater.stop() + if self.fallback_updater is not None: + self.fallback_updater.stop() self.stop_main_retry() self.stop_fallback_retry() except Exception as e: @@ -199,7 +202,8 @@ def stop(self): self.main_retry_stopper.set() self.fallback_retry_stopper.set() self.main_updater.stop() - self.fallback_updater.stop() + if self.fallback_updater is not None: + self.fallback_updater.stop() def start_main_retry(self, main_on_error: Callable[[str], None]): with self.lock: @@ -219,7 +223,8 @@ def retry_main(): self.main_updater.start(main_on_error) stopper.set() self.stop_fallback_retry() - self.fallback_updater.stop() + if self.fallback_updater is not None: + self.fallback_updater.stop() break except: pass @@ -233,6 +238,9 @@ def start_fallback_retry(self, fallback_on_error: Callable[[str], None]): if 
self.fallback_retry_stopper: self.fallback_retry_stopper.set() + if self.fallback_updater is None: + return + stopper = threading.Event() def retry_fallback(): @@ -243,7 +251,8 @@ def retry_fallback(): if stopper.is_set(): break try: - self.fallback_updater.start(fallback_on_error) + if self.fallback_updater is not None: + self.fallback_updater.start(fallback_on_error) stopper.set() break except: diff --git a/src/amplitude_experiment/flag/main.py b/src/amplitude_experiment/flag/main.py deleted file mode 100644 index c874120..0000000 --- a/src/amplitude_experiment/flag/main.py +++ /dev/null @@ -1,28 +0,0 @@ -import json -import logging -import time - -from amplitude_experiment.flag import FlagConfigStreamApi, FlagConfigStreamer -from amplitude_experiment.flag.flag_config_storage import FlagConfigStorage, InMemoryFlagConfigStorage -from amplitude_experiment.flag.flag_config_updater import DummyUpdater, FlagConfigUpdaterFallbackRetryWrapper - -if __name__ == '__main__': - logger = logging.Logger("a") - api = FlagConfigStreamApi('server-tUTqR62DZefq7c73zMpbIr1M5VDtwY8T', 'https://skylab-stream.stag2.amplitude.com', 1500, 1000 * 5, 0) - storage = InMemoryFlagConfigStorage() - # streamer = FlagConfigStreamer(api, storage, None, None, logger) - - dummy1 = DummyUpdater("dummy 1") - # dummy1.start_fail = True - dummy1.fail = True - dummy2 = DummyUpdater("dummy 2") - dummy2.start_fail = True - streamer = FlagConfigUpdaterFallbackRetryWrapper(dummy1, dummy2, logger) - - print("start") - streamer.start(print) - # print(storage.get_flag_configs()) - time.sleep(20) - - streamer.stop() - print("done") \ No newline at end of file diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index f64a332..337a49a 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -3,6 +3,8 @@ from unittest.mock import patch import logging +from amplitude_experiment.evaluation.types import EvaluationFlag 
+ from src.amplitude_experiment import LocalEvaluationConfig from src.amplitude_experiment.cohort.cohort_loader import CohortLoader from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig @@ -44,6 +46,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): runner = DeploymentRunner( LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), flag_api, + None, flag_config_storage, cohort_storage, logger, @@ -63,15 +66,15 @@ def test_start_does_not_throw_if_cohort_load_fails(self): cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), - flag_api, flag_config_storage, + flag_api, None, flag_config_storage, cohort_storage, logger, cohort_loader, ) # Mock methods as needed - with patch.object(runner, '_delete_unused_cohorts'): - flag_api.get_flag_configs.return_value = [self.flag] + with patch.object(runner.flag_updater.main_updater, '_delete_unused_cohorts'): + flag_api.get_flag_configs.return_value = EvaluationFlag.schema().load([self.flag], many=True) cohort_download_api.get_cohort.side_effect = RuntimeError("test") # Simply call the method and let the test pass if no exception is raised diff --git a/tests/flag/flag_config_api_test.py b/tests/flag/flag_config_api_test.py index 9159d9e..35ca2a4 100644 --- a/tests/flag/flag_config_api_test.py +++ b/tests/flag/flag_config_api_test.py @@ -5,6 +5,8 @@ import unittest from unittest.mock import MagicMock, patch +from amplitude_experiment.evaluation.types import EvaluationFlag + from amplitude_experiment.flag.flag_config_api import FlagConfigStreamApi @@ -16,22 +18,41 @@ def response(code: int, body: dict = None): return mock_response +BARE_FLAG = [{ + "key": "flag", + "variants": {}, + "segments": [ + { + "conditions": [ + [ + { + "selector": ["context", "user", "cohort_ids"], + "op": "set contains any", + "values": ["COHORT_ID"], + } + ] + ], 
+ } + ] +}] + + class FlagConfigStreamApiTest(unittest.TestCase): def setUp(self) -> None: self.api = FlagConfigStreamApi("deployment_key", "server_url", 2000, 5000, 0) self.success_count = 0 self.error_count = 0 def on_success(self, data): - self.success_count += (1 if data == "apple" else 0) + self.success_count += (1 if data[0].key == "flag" else 0) def on_error(self, data): self.error_count += 1 def test_connect_and_get_data_success(self): with patch.object(self.api.eventsource, 'start') as es: assert self.success_count == 0 - threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() + threading.Thread(target=self.api.start, args=[self.on_success, self.on_error]).start() time.sleep(1) - es.call_args[0][0]('"apple"') + es.call_args[0][0](json.dumps(BARE_FLAG)) assert self.success_count == 1 assert self.error_count == 0 @@ -60,7 +81,7 @@ def test_connect_success_but_error_later(self): assert self.error_count == 0 threading.Thread(target=lambda: self.api.start(self.on_success, self.on_error)).start() time.sleep(1) - es.call_args[0][0]('"apple"') + es.call_args[0][0](json.dumps(BARE_FLAG)) assert self.success_count == 1 assert self.error_count == 0 es.call_args[0][1]('error') diff --git a/tests/flag/flag_config_updater_test.py b/tests/flag/flag_config_updater_test.py index d93b7e6..b9329d6 100644 --- a/tests/flag/flag_config_updater_test.py +++ b/tests/flag/flag_config_updater_test.py @@ -102,10 +102,14 @@ def test_main_start_failed_fallback_start_success_main_start_success_on_retry(se self.api.start(None) assert self.dummy1.start_count == 1 assert self.dummy2.start_count == 1 - self.dummy1.fail_time = 0 + self.dummy1.fail_time = -1 time.sleep(1.1) assert self.dummy1.start_count == 2 assert self.dummy2.start_count == 1 + # No more restarts + time.sleep(2) + assert self.dummy1.start_count == 2 + assert self.dummy2.start_count == 1 def test_main_start_failed_fallback_start_failed(self): self.dummy1.fail_time = 0 @@ -219,4 +223,4 @@ def 
test_main_start_success_later_failed_fallback_start_failed_later_stopped(sel # Now main failed assert self.dummy1.start_count == 1 assert self.dummy2.start_count == 1 - + self.api.stop() From c8c20df44384328adb301001b7c52d4110480297 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 21 Dec 2024 06:41:16 -0800 Subject: [PATCH 3/5] Remove unused imports --- src/amplitude_experiment/deployment/deployment_runner.py | 2 +- src/amplitude_experiment/flag/flag_config_api.py | 6 ++---- src/amplitude_experiment/flag/flag_config_updater.py | 2 +- tests/flag/flag_config_api_test.py | 3 --- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 000af7f..6532450 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -9,7 +9,7 @@ from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller -from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags +from ..util.flag_config import get_all_cohort_ids_from_flags DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 15000 DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000 diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index 9411546..b8f0a40 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -1,13 +1,11 @@ import json -import random import threading -import time from http.client import HTTPResponse, HTTPConnection, HTTPSConnection -from typing import List, Tuple, Optional, Callable, Mapping +from typing import List, Optional, Callable, Mapping import sseclient -from ..connection_pool import HTTPConnectionPool, WrapperHTTPConnection +from ..connection_pool import 
HTTPConnectionPool from ..util.updater import get_duration_with_jitter from ..evaluation.types import EvaluationFlag from ..version import __version__ diff --git a/src/amplitude_experiment/flag/flag_config_updater.py b/src/amplitude_experiment/flag/flag_config_updater.py index 1c6b045..232c75b 100644 --- a/src/amplitude_experiment/flag/flag_config_updater.py +++ b/src/amplitude_experiment/flag/flag_config_updater.py @@ -10,7 +10,7 @@ from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller from ..cohort.cohort_loader import CohortLoader -from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags +from ..util.flag_config import get_all_cohort_ids_from_flag from ..util.updater import get_duration_with_jitter diff --git a/tests/flag/flag_config_api_test.py b/tests/flag/flag_config_api_test.py index 35ca2a4..e6cef97 100644 --- a/tests/flag/flag_config_api_test.py +++ b/tests/flag/flag_config_api_test.py @@ -1,12 +1,9 @@ import json -import logging import threading import time import unittest from unittest.mock import MagicMock, patch -from amplitude_experiment.evaluation.types import EvaluationFlag - from amplitude_experiment.flag.flag_config_api import FlagConfigStreamApi From 12d8bfb090be91e7523129531454ece46253f8c4 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 21 Dec 2024 06:43:35 -0800 Subject: [PATCH 4/5] Fix union type --- src/amplitude_experiment/flag/flag_config_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index b8f0a40..b4b2e59 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -1,7 +1,7 @@ import json import threading from http.client import HTTPResponse, HTTPConnection, HTTPSConnection -from typing import List, Optional, Callable, Mapping +from typing import List, Optional, Callable, 
Mapping, Union import sseclient @@ -146,7 +146,7 @@ def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None] return on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e)) - def _get_conn(self) -> tuple[HTTPConnection | HTTPSConnection, HTTPResponse]: + def _get_conn(self) -> tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]: scheme, _, host = self.server_url.split('/', 3) connection = HTTPConnection if scheme == 'http:' else HTTPSConnection From f4851fdda565eb880f20102cbbde24f2bf377b81 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 21 Dec 2024 06:46:17 -0800 Subject: [PATCH 5/5] Fix typing --- src/amplitude_experiment/flag/flag_config_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index b4b2e59..662f3cf 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -1,7 +1,7 @@ import json import threading from http.client import HTTPResponse, HTTPConnection, HTTPSConnection -from typing import List, Optional, Callable, Mapping, Union +from typing import List, Optional, Callable, Mapping, Union, Tuple import sseclient @@ -146,7 +146,7 @@ def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None] return on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e)) - def _get_conn(self) -> tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]: + def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]: scheme, _, host = self.server_url.split('/', 3) connection = HTTPConnection if scheme == 'http:' else HTTPSConnection