From 505eae540467c0e1ee549eddde436acfdb6bf889 Mon Sep 17 00:00:00 2001
From: Guillaume De Saint Martin
Date: Wed, 25 Oct 2023 23:26:21 +0200
Subject: [PATCH] [GPT] enable historical signals fetch

---
 Evaluator/TA/ai_evaluator/ai.py            |  70 ++++++++--
 Services/Services_bases/gpt_service/gpt.py | 152 ++++++++++++++++++++-
 2 files changed, 204 insertions(+), 18 deletions(-)

diff --git a/Evaluator/TA/ai_evaluator/ai.py b/Evaluator/TA/ai_evaluator/ai.py
index 750c58351..f5370cda5 100644
--- a/Evaluator/TA/ai_evaluator/ai.py
+++ b/Evaluator/TA/ai_evaluator/ai.py
@@ -31,8 +31,10 @@
 
 
 class GPTEvaluator(evaluators.TAEvaluator):
+    GLOBAL_VERSION = 1
     PREPROMPT = "Predict: {up or down} {confidence%} (no other information)"
     PASSED_DATA_LEN = 10
+    MAX_CONFIDENCE_PERCENT = 100
     HIGH_CONFIDENCE_PERCENT = 80
     MEDIUM_CONFIDENCE_PERCENT = 50
     LOW_CONFIDENCE_PERCENT = 30
@@ -53,6 +55,7 @@ def __init__(self, tentacles_setup_config):
         self.indicator = None
         self.source = None
         self.period = None
+        self.max_confidence_threshold = 0
         self.gpt_model = gpt_service.GPTService.DEFAULT_MODEL
         self.is_backtesting = False
         self.min_allowed_timeframe = os.getenv("MIN_GPT_TIMEFRAME", None)
@@ -65,6 +68,7 @@ def __init__(self, tentacles_setup_config):
             except ValueError:
                 self.logger.error(f"Invalid timeframe configuration: unknown timeframe: '{self.min_allowed_timeframe}'")
         self.allow_reevaluations = os_util.parse_boolean_environment_var("ALLOW_GPT_REEVALUATIONS", "True")
+        self.services_config = None
 
     def enable_reevaluation(self) -> bool:
         """
@@ -72,6 +76,13 @@ def enable_reevaluation(self) -> bool:
         """
         return self.allow_reevaluations
 
+    @classmethod
+    def get_signals_history_type(cls):
+        """
+        Override when this evaluator uses a specific type of signal history
+        """
+        return commons_enums.SignalHistoryTypes.GPT
+
     async def load_and_save_user_inputs(self, bot_id: str) -> dict:
         """
         instance method API for user inputs
         """
@@ -100,6 +111,11 @@ def init_user_inputs(self, inputs: dict) -> None:
             self.period, inputs, min_val=1, title="Period: length of the indicator period."
         )
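+        # Parsed confidences at or above this threshold are raised to
+        # MAX_CONFIDENCE_PERCENT, producing a full 1 / -1 evaluation; 0 leaves
+        # parsed confidences untouched.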
+        self.max_confidence_threshold = self.UI.user_input(
+            "max_confidence_threshold", enums.UserInputTypes.INT,
+            self.max_confidence_threshold, inputs, min_val=0, max_val=100,
+            title="Maximum confidence threshold: % confidence value starting from which to return 1 or -1."
+        )
         if len(self.GPT_MODELS) > 1 and self.enable_model_selector:
             self.gpt_model = self.UI.user_input(
                 "GPT model", enums.UserInputTypes.OPTIONS, gpt_service.GPTService.DEFAULT_MODEL,
@@ -112,7 +128,9 @@ async def _init_GPT_models(self):
         self.GPT_MODELS = [gpt_service.GPTService.DEFAULT_MODEL]
         if self.enable_model_selector and not self.is_backtesting:
             try:
-                service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
+                service = await services_api.get_service(
+                    gpt_service.GPTService, self.is_backtesting, self.services_config
+                )
                 self.GPT_MODELS = service.models
             except Exception as err:
                 self.logger.exception(err, True, f"Impossible to fetch GPT models: {err}")
@@ -138,13 +156,14 @@ async def evaluate(self, cryptocurrency, symbol, time_frame, candle_data, candle
         self.eval_note = commons_constants.START_PENDING_EVAL_NOTE
         if self._check_timeframe(time_frame):
             try:
+                candle_time = candle[commons_enums.PriceIndexes.IND_PRICE_TIME.value]
                 computed_data = self.call_indicator(candle_data)
                 reduced_data = computed_data[-self.PASSED_DATA_LEN:]
                 formatted_data = ", ".join(str(datum).replace('[', '').replace(']', '') for datum in reduced_data)
-                prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame)
+                prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame, candle_time)
                 cleaned_prediction = prediction.strip().replace("\n", "").replace(".", "").lower()
                 prediction_side = self._parse_prediction_side(cleaned_prediction)
-                if prediction_side == 0:
+                if prediction_side == 0 and not self.is_backtesting:
                     self.logger.error(f"Error when reading GPT answer: {cleaned_prediction}")
                     return
                 confidence = self._parse_confidence(cleaned_prediction) / 100
@@ -171,20 +190,35 @@
                                         eval_time=evaluators_util.get_eval_time(full_candle=candle, time_frame=time_frame))
 
-    async def ask_gpt(self, preprompt, inputs, symbol, time_frame) -> str:
+    async def ask_gpt(self, preprompt, inputs, symbol, time_frame, candle_time) -> str:
         try:
-            service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
+            service = await services_api.get_service(
+                gpt_service.GPTService,
+                self.is_backtesting,
+                {} if self.is_backtesting else self.services_config
+            )
             resp = await service.get_chat_completion(
                 [
                     service.create_message("system", preprompt),
                     service.create_message("user", inputs),
                 ],
-                model=self.gpt_model if self.enable_model_selector else None
+                model=self.gpt_model if self.enable_model_selector else None,
+                exchange=self.exchange_name,
+                symbol=symbol,
+                time_frame=time_frame,
+                version=self.get_version(),
+                candle_open_time=candle_time,
+                use_stored_signals=self.is_backtesting
             )
             self.logger.info(f"GPT's answer is '{resp}' for {symbol} on {time_frame} with input: {inputs}")
             return resp
         except services_errors.CreationError as err:
             raise evaluators_errors.UnavailableEvaluatorError(f"Impossible to get ChatGPT prediction: {err}") from err
+        except Exception as err:
+            self.logger.exception(err, True, f"Unexpected error when asking GPT: {err}")
+            return ""
+
+    def get_version(self):
+        return f"{self.gpt_model}-{self.source}-{self.indicator}-{self.period}-{self.GLOBAL_VERSION}"
 
     def call_indicator(self, candle_data):
         return data_util.drop_nan(self.INDICATORS[self.indicator](candle_data, self.period))
@@ -216,14 +250,20 @@ def _parse_confidence(self, cleaned_prediction):
         up with 70% confidence
         up with high confidence
         """
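+        # Default to low confidence when the answer does not state an explicit level.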
+        value = self.LOW_CONFIDENCE_PERCENT
         if "%" in cleaned_prediction:
             percent_index = cleaned_prediction.index("%")
-            return float(cleaned_prediction[:percent_index].split(" ")[-1])
-        if "high" in cleaned_prediction:
-            return self.HIGH_CONFIDENCE_PERCENT
-        if "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
-            return self.MEDIUM_CONFIDENCE_PERCENT
-        if "low" in cleaned_prediction:
-            return self.LOW_CONFIDENCE_PERCENT
-        self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
-        return self.LOW_CONFIDENCE_PERCENT
+            value = float(cleaned_prediction[:percent_index].split(" ")[-1])
+        elif "high" in cleaned_prediction:
+            value = self.HIGH_CONFIDENCE_PERCENT
+        elif "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
+            value = self.MEDIUM_CONFIDENCE_PERCENT
+        elif "low" in cleaned_prediction:
+            value = self.LOW_CONFIDENCE_PERCENT
+        elif not cleaned_prediction:
+            value = 0
+        else:
+            self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
+        if self.max_confidence_threshold and value >= self.max_confidence_threshold:
+            return self.MAX_CONFIDENCE_PERCENT
+        return value

diff --git a/Services/Services_bases/gpt_service/gpt.py b/Services/Services_bases/gpt_service/gpt.py
index 02549b29b..2385342ac 100644
--- a/Services/Services_bases/gpt_service/gpt.py
+++ b/Services/Services_bases/gpt_service/gpt.py
@@ -13,6 +13,7 @@
 #
 # You should have received a copy of the GNU Lesser General Public
 # License along with this library.
+import asyncio
 import os
 import openai
 import logging
@@ -21,10 +22,19 @@
 import octobot_services.constants as services_constants
 import octobot_services.services as services
 import octobot_services.errors as errors
+
+import octobot_commons.enums as commons_enums
+import octobot_commons.constants as commons_constants
+import octobot_commons.os_util as os_util
+import octobot_commons.authentication as authentication
+import octobot_commons.tree as tree
+
 import octobot.constants as constants
+import octobot.community as community
 
 
 class GPTService(services.AbstractService):
+    BACKTESTING_ENABLED = True
     DEFAULT_MODEL = "gpt-3.5-turbo"
 
     def get_fields_description(self):
@@ -46,6 +56,7 @@ def __init__(self):
         logging.getLogger("openai").setLevel(logging.WARNING)
         self._env_secret_key = os.getenv(services_constants.ENV_OPENAI_SECRET_KEY, None)
         self.model = os.getenv(services_constants.ENV_GPT_MODEL, self.DEFAULT_MODEL)
+        self.stored_signals: tree.BaseTree = tree.BaseTree()
         self.models = []
         self.daily_tokens_limit = int(os.getenv(services_constants.ENV_GPT_DAILY_TOKENS_LIMIT, 0))
         self.consumed_daily_tokens = 1
@@ -63,6 +74,27 @@ async def get_chat_completion(
             n=1,
             stop=None,
             temperature=0.5,
+            exchange: str = None,
+            symbol: str = None,
+            time_frame: str = None,
+            version: str = None,
+            candle_open_time: float = None,
+            use_stored_signals: bool = False,
+    ) -> str:
+        if use_stored_signals:
+            return self._get_signal_from_stored_signals(exchange, symbol, time_frame, version, candle_open_time)
+        if self.use_stored_signals_only():
+            return await self._fetch_signal_from_stored_signals(exchange, symbol, time_frame, version, candle_open_time)
+        return await self._get_signal_from_gpt(messages, model, max_tokens, n, stop, temperature)
+
+    async def _get_signal_from_gpt(
+            self,
+            messages,
+            model=None,
+            max_tokens=3000,
+            n=1,
+            stop=None,
+            temperature=0.5
     ):
         self._ensure_rate_limit()
         try:
@@ -87,6 +119,111 @@
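+    # In-memory signal cache: a tree keyed by
+    # [exchange, symbol, time_frame, version, candle_open_time].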
+    def _get_signal_from_stored_signals(
+        self,
+        exchange: str,
+        symbol: str,
+        time_frame: str,
+        version: str,
+        candle_open_time: float,
+    ):
+        try:
+            return self.stored_signals.get_node([exchange, symbol, time_frame, version, candle_open_time]).node_value
+        except tree.NodeExistsError:
+            return ""
+
+    async def _fetch_signal_from_stored_signals(
+        self,
+        exchange: str,
+        symbol: str,
+        time_frame: str,
+        version: str,
+        candle_open_time: float,
+    ) -> str:
+        authenticator = authentication.Authenticator.instance()
+        try:
+            return await authenticator.get_gpt_signal(
+                exchange, symbol, commons_enums.TimeFrames(time_frame), candle_open_time, version
+            )
+        except Exception as err:
+            self.logger.exception(err, True, f"Error when fetching gpt signal: {err}")
+            return ""
+
+    def store_signal_history(
+        self,
+        exchange: str,
+        symbol: str,
+        time_frame: commons_enums.TimeFrames,
+        version: str,
+        signals_by_candle_open_time,
+    ):
+        tf = time_frame.value
+        for candle_open_time, signal in signals_by_candle_open_time.items():
+            self.stored_signals.set_node_at_path(
+                signal,
+                str,
+                [exchange, symbol, tf, version, candle_open_time]
+            )
+
+    def has_signal_history(
+        self,
+        exchange: str,
+        symbol: str,
+        time_frame: commons_enums.TimeFrames,
+        min_timestamp: float,
+        max_timestamp: float,
+        version: str
+    ):
+        for ts in (min_timestamp, max_timestamp):
+            if self._get_signal_from_stored_signals(
+                exchange, symbol, time_frame.value, version, self._get_open_candle_timestamp(time_frame, ts)
+            ) == "":
+                return False
+        return True
+
+    async def _fetch_and_store_history(
+        self, authenticator, exchange_name, symbol, time_frame, version, min_timestamp: float, max_timestamp: float
+    ):
+        signals_by_candle_open_time = await authenticator.get_gpt_signals_history(
+            exchange_name, symbol, time_frame,
+            self._get_open_candle_timestamp(time_frame, min_timestamp),
+            self._get_open_candle_timestamp(time_frame, max_timestamp),
+            version
+        )
+        if not signals_by_candle_open_time:
+            self.logger.error(
+                f"No ChatGPT signal history for {symbol} on {time_frame.value} for {exchange_name} with {version}. "
+                f"Please check {self._supported_history_url()} to get the list of supported signals history."
+            )
+        self.store_signal_history(
+            exchange_name, symbol, time_frame, version, signals_by_candle_open_time
+        )
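+    # Prefetches the signal history of every requested (symbol, time frame) pair,
+    # intended to run before backtests so that candles then resolve from memory.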
+    async def fetch_gpt_history(
+        self, exchange_name: str, symbols: list, time_frames: list,
+        version: str, start_timestamp: float, end_timestamp: float
+    ):
+        authenticator = authentication.Authenticator.instance()
+        coros = [
+            self._fetch_and_store_history(
+                authenticator, exchange_name, symbol, time_frame, version, start_timestamp, end_timestamp
+            )
+            for symbol in symbols
+            for time_frame in time_frames
+            if not self.has_signal_history(exchange_name, symbol, time_frame, start_timestamp, end_timestamp, version)
+        ]
+        if coros:
+            await asyncio.gather(*coros)
+
+    def _get_open_candle_timestamp(self, time_frame: commons_enums.TimeFrames, base_timestamp: float):
+        tf_seconds = commons_enums.TimeFramesMinutes[time_frame] * commons_constants.MINUTE_TO_SECONDS
+        return base_timestamp - (base_timestamp % tf_seconds)
+
+    def clear_signal_history(self):
+        self.stored_signals.clear()
+
+    def _supported_history_url(self):
+        return f"{community.IdentifiersProvider.COMMUNITY_LANDING_URL}/chat-gpt-trading"
+
     def _ensure_rate_limit(self):
         if self.last_consumed_token_date != datetime.date.today():
             self.consumed_daily_tokens = 0
@@ -101,7 +238,7 @@ def _update_token_usage(self, consumed_tokens):
         self.logger.debug(f"Consumed {consumed_tokens} tokens. {self.consumed_daily_tokens} consumed tokens today.")
 
     def check_required_config(self, config):
-        if self._env_secret_key is not None:
+        if self._env_secret_key is not None or self.use_stored_signals_only():
             return True
         try:
             return bool(config[services_constants.CONIG_OPENAI_SECRET_KEY])
@@ -110,6 +247,8 @@ def has_required_configuration(self):
         try:
+            if self.use_stored_signals_only():
+                return True
             return self.check_required_config(
                 self.config[services_constants.CONFIG_CATEGORY_SERVICES].get(services_constants.CONFIG_GPT, {})
             )
@@ -140,6 +279,9 @@ def _get_api_key(self):
 
     async def prepare(self) -> None:
         try:
+            if self.use_stored_signals_only():
+                self.logger.info("Skipping models fetch as use_stored_signals_only() is True")
+                return
             fetched_models = await openai.Model.alist(api_key=self._get_api_key())
             self.models = [d["id"] for d in fetched_models["data"]]
             if self.model not in self.models:
@@ -151,11 +293,15 @@ async def prepare(self) -> None:
             self.logger.error(f"Unexpected error when checking api key: {err}")
 
     def _is_healthy(self):
-        return self._get_api_key() and self.models
+        return self.use_stored_signals_only() or (self._get_api_key() and self.models)
 
     def get_successful_startup_message(self):
-        return f"GPT configured and ready. {len(self.models)} AI models are available. Using {self.model}.", \
+        return f"GPT configured and ready. {len(self.models)} AI models are available. " \
+               f"Using {'stored signals' if self.use_stored_signals_only() else self.model}.", \
                self._is_healthy()
 
+    def use_stored_signals_only(self):
+        return not self.config
+
     async def stop(self):
         pass