[GPT] enable historical signals fetch
GuillaumeDSM committed Oct 27, 2023
1 parent dc2de1e commit 505eae5
Showing 2 changed files with 204 additions and 18 deletions.
70 changes: 55 additions & 15 deletions Evaluator/TA/ai_evaluator/ai.py
@@ -31,8 +31,10 @@


class GPTEvaluator(evaluators.TAEvaluator):
GLOBAL_VERSION = 1
PREPROMPT = "Predict: {up or down} {confidence%} (no other information)"
PASSED_DATA_LEN = 10
MAX_CONFIDENCE_PERCENT = 100
HIGH_CONFIDENCE_PERCENT = 80
MEDIUM_CONFIDENCE_PERCENT = 50
LOW_CONFIDENCE_PERCENT = 30
@@ -53,6 +55,7 @@ def __init__(self, tentacles_setup_config):
self.indicator = None
self.source = None
self.period = None
self.max_confidence_threshold = 0
self.gpt_model = gpt_service.GPTService.DEFAULT_MODEL
self.is_backtesting = False
self.min_allowed_timeframe = os.getenv("MIN_GPT_TIMEFRAME", None)
@@ -65,13 +68,21 @@ def __init__(self, tentacles_setup_config):
except ValueError:
self.logger.error(f"Invalid timeframe configuration: unknown timeframe: '{self.min_allowed_timeframe}'")
self.allow_reevaluations = os_util.parse_boolean_environment_var("ALLOW_GPT_REEVALUATIONS", "True")
self.services_config = None

def enable_reevaluation(self) -> bool:
"""
Override when artificial re-evaluations from the evaluator channel can be disabled
"""
return self.allow_reevaluations

@classmethod
def get_signals_history_type(cls):
"""
Override when this evaluator uses a specific type of signal history
"""
return commons_enums.SignalHistoryTypes.GPT

async def load_and_save_user_inputs(self, bot_id: str) -> dict:
"""
instance method API for user inputs
@@ -100,6 +111,11 @@ def init_user_inputs(self, inputs: dict) -> None:
self.period, inputs, min_val=1,
title="Period: length of the indicator period."
)
self.max_confidence_threshold = self.UI.user_input(
"max_confidence_threshold", enums.UserInputTypes.INT,
self.max_confidence_threshold, inputs, min_val=0, max_val=100,
title="Maximum confidence threshold: % confidence value starting from which to return 1 or -1."
)
if len(self.GPT_MODELS) > 1 and self.enable_model_selector:
self.gpt_model = self.UI.user_input(
"GPT model", enums.UserInputTypes.OPTIONS, gpt_service.GPTService.DEFAULT_MODEL,
@@ -112,7 +128,9 @@ async def _init_GPT_models(self):
self.GPT_MODELS = [gpt_service.GPTService.DEFAULT_MODEL]
if self.enable_model_selector and not self.is_backtesting:
try:
service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
service = await services_api.get_service(
gpt_service.GPTService, self.is_backtesting, self.services_config
)
self.GPT_MODELS = service.models
except Exception as err:
self.logger.exception(err, True, f"Impossible to fetch GPT models: {err}")
@@ -138,13 +156,14 @@ async def evaluate(self, cryptocurrency, symbol, time_frame, candle_data, candle
self.eval_note = commons_constants.START_PENDING_EVAL_NOTE
if self._check_timeframe(time_frame):
try:
candle_time = candle[commons_enums.PriceIndexes.IND_PRICE_TIME.value]
computed_data = self.call_indicator(candle_data)
reduced_data = computed_data[-self.PASSED_DATA_LEN:]
formatted_data = ", ".join(str(datum).replace('[', '').replace(']', '') for datum in reduced_data)
prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame)
prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame, candle_time)
cleaned_prediction = prediction.strip().replace("\n", "").replace(".", "").lower()
prediction_side = self._parse_prediction_side(cleaned_prediction)
if prediction_side == 0:
if prediction_side == 0 and not self.is_backtesting:
self.logger.error(f"Error when reading GPT answer: {cleaned_prediction}")
return
confidence = self._parse_confidence(cleaned_prediction) / 100
@@ -171,20 +190,35 @@
eval_time=evaluators_util.get_eval_time(full_candle=candle,
time_frame=time_frame))

async def ask_gpt(self, preprompt, inputs, symbol, time_frame) -> str:
async def ask_gpt(self, preprompt, inputs, symbol, time_frame, candle_time) -> str:
try:
service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
service = await services_api.get_service(
gpt_service.GPTService,
self.is_backtesting,
{} if self.is_backtesting else self.services_config
)
resp = await service.get_chat_completion(
[
service.create_message("system", preprompt),
service.create_message("user", inputs),
],
model=self.gpt_model if self.enable_model_selector else None
model=self.gpt_model if self.enable_model_selector else None,
exchange=self.exchange_name,
symbol=symbol,
time_frame=time_frame,
version=self.get_version(),
candle_open_time=candle_time,
use_stored_signals=self.is_backtesting
)
self.logger.info(f"GPT's answer is '{resp}' for {symbol} on {time_frame} with input: {inputs}")
return resp
except services_errors.CreationError as err:
raise evaluators_errors.UnavailableEvaluatorError(f"Impossible to get ChatGPT prediction: {err}") from err
except Exception as err:
self.logger.exception(err, True, f"Unexpected error when requesting a ChatGPT prediction: {err}")
raise evaluators_errors.UnavailableEvaluatorError(f"Unexpected error when requesting a ChatGPT prediction: {err}") from err

def get_version(self):
return f"{self.gpt_model}-{self.source}-{self.indicator}-{self.period}-{self.GLOBAL_VERSION}"

def call_indicator(self, candle_data):
return data_util.drop_nan(self.INDICATORS[self.indicator](candle_data, self.period))
@@ -216,14 +250,20 @@ def _parse_confidence(self, cleaned_prediction):
up with 70% confidence
up with high confidence
"""
value = self.LOW_CONFIDENCE_PERCENT
if "%" in cleaned_prediction:
percent_index = cleaned_prediction.index("%")
return float(cleaned_prediction[:percent_index].split(" ")[-1])
if "high" in cleaned_prediction:
return self.HIGH_CONFIDENCE_PERCENT
if "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
return self.MEDIUM_CONFIDENCE_PERCENT
if "low" in cleaned_prediction:
return self.LOW_CONFIDENCE_PERCENT
self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
return self.LOW_CONFIDENCE_PERCENT
value = float(cleaned_prediction[:percent_index].split(" ")[-1])
elif "high" in cleaned_prediction:
value = self.HIGH_CONFIDENCE_PERCENT
elif "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
value = self.MEDIUM_CONFIDENCE_PERCENT
elif "low" in cleaned_prediction:
value = self.LOW_CONFIDENCE_PERCENT
elif not cleaned_prediction:
value = 0
else:
self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
if value >= self.max_confidence_threshold:
return self.MAX_CONFIDENCE_PERCENT
return value
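
For reference, here is a minimal, self-contained sketch of the confidence handling introduced above. The constants mirror the ones defined on GPTEvaluator; the standalone parse_confidence helper and the sample predictions are illustrative, not part of the commit (the warning log for unparseable non-empty predictions is omitted).

# Assumed constants, copied from GPTEvaluator above.
MAX_CONFIDENCE_PERCENT = 100
HIGH_CONFIDENCE_PERCENT = 80
MEDIUM_CONFIDENCE_PERCENT = 50
LOW_CONFIDENCE_PERCENT = 30

def parse_confidence(cleaned_prediction: str, max_confidence_threshold: int) -> float:
    # standalone equivalent of the new GPTEvaluator._parse_confidence logic
    value = LOW_CONFIDENCE_PERCENT
    if "%" in cleaned_prediction:
        percent_index = cleaned_prediction.index("%")
        value = float(cleaned_prediction[:percent_index].split(" ")[-1])
    elif "high" in cleaned_prediction:
        value = HIGH_CONFIDENCE_PERCENT
    elif "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
        value = MEDIUM_CONFIDENCE_PERCENT
    elif "low" in cleaned_prediction:
        value = LOW_CONFIDENCE_PERCENT
    elif not cleaned_prediction:
        value = 0
    # values at or above the configured threshold are promoted to full confidence
    if value >= max_confidence_threshold:
        return MAX_CONFIDENCE_PERCENT
    return value

print(parse_confidence("up with 70% confidence", max_confidence_threshold=80))     # 70.0
print(parse_confidence("down with high confidence", max_confidence_threshold=80))  # 100

Note that with the default max_confidence_threshold of 0, every parsed value satisfies value >= 0 and is therefore returned as MAX_CONFIDENCE_PERCENT; the threshold only starts to discriminate once it is raised through the new user input.
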
152 changes: 149 additions & 3 deletions Services/Services_bases/gpt_service/gpt.py
@@ -13,6 +13,7 @@
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
import asyncio
import os
import openai
import logging
@@ -21,10 +22,19 @@
import octobot_services.constants as services_constants
import octobot_services.services as services
import octobot_services.errors as errors

import octobot_commons.enums as commons_enums
import octobot_commons.constants as commons_constants
import octobot_commons.os_util as os_util
import octobot_commons.authentication as authentication
import octobot_commons.tree as tree

import octobot.constants as constants
import octobot.community as community


class GPTService(services.AbstractService):
BACKTESTING_ENABLED = True
DEFAULT_MODEL = "gpt-3.5-turbo"

def get_fields_description(self):
@@ -46,6 +56,7 @@ def __init__(self):
logging.getLogger("openai").setLevel(logging.WARNING)
self._env_secret_key = os.getenv(services_constants.ENV_OPENAI_SECRET_KEY, None)
self.model = os.getenv(services_constants.ENV_GPT_MODEL, self.DEFAULT_MODEL)
self.stored_signals: tree.BaseTree = tree.BaseTree()
self.models = []
self.daily_tokens_limit = int(os.getenv(services_constants.ENV_GPT_DAILY_TOKENS_LIMIT, 0))
self.consumed_daily_tokens = 1
@@ -63,6 +74,27 @@ async def get_chat_completion(
n=1,
stop=None,
temperature=0.5,
exchange: str = None,
symbol: str = None,
time_frame: str = None,
version: str = None,
candle_open_time: float = None,
use_stored_signals: bool = False,
) -> str:
if use_stored_signals:
return self._get_signal_from_stored_signals(exchange, symbol, time_frame, version, candle_open_time)
if self.use_stored_signals_only():
return await self._fetch_signal_from_stored_signals(exchange, symbol, time_frame, version, candle_open_time)
return await self._get_signal_from_gpt(messages, model, max_tokens, n, stop, temperature)

async def _get_signal_from_gpt(
self,
messages,
model=None,
max_tokens=3000,
n=1,
stop=None,
temperature=0.5
):
self._ensure_rate_limit()
try:
@@ -87,6 +119,111 @@
f"Unexpected error when running request with model {model}: {err}"
) from err

def _get_signal_from_stored_signals(
self,
exchange: str,
symbol: str,
time_frame: str,
version: str,
candle_open_time: float,
):
try:
return self.stored_signals.get_node([exchange, symbol, time_frame, version, candle_open_time]).node_value
except tree.NodeExistsError:
return ""

async def _fetch_signal_from_stored_signals(
self,
exchange: str,
symbol: str,
time_frame: str,
version: str,
candle_open_time: float,
) -> str:
authenticator = authentication.Authenticator.instance()
try:
return await authenticator.get_gpt_signal(
exchange, symbol, commons_enums.TimeFrames(time_frame), candle_open_time, version
)
except Exception as err:
self.logger.exception(err, True, f"Error when fetching gpt signal: {err}")

def store_signal_history(
self,
exchange: str,
symbol: str,
time_frame: commons_enums.TimeFrames,
version: str,
signals_by_candle_open_time,
):
tf = time_frame.value
for candle_open_time, signal in signals_by_candle_open_time.items():
self.stored_signals.set_node_at_path(
signal,
str,
[exchange, symbol, tf, version, candle_open_time]
)

def has_signal_history(
self,
exchange: str,
symbol: str,
time_frame: commons_enums.TimeFrames,
min_timestamp: float,
max_timestamp: float,
version: str
):
for ts in (min_timestamp, max_timestamp):
if self._get_signal_from_stored_signals(
exchange, symbol, time_frame.value, version, self._get_open_candle_timestamp(time_frame, ts)
) == "":
return False
return True

async def _fetch_and_store_history(
self, authenticator, exchange_name, symbol, time_frame, version, min_timestamp: float, max_timestamp: float
):
signals_by_candle_open_time = await authenticator.get_gpt_signals_history(
exchange_name, symbol, time_frame,
self._get_open_candle_timestamp(time_frame, min_timestamp),
self._get_open_candle_timestamp(time_frame, max_timestamp),
version
)
if not signals_by_candle_open_time:
self.logger.error(
f"No ChatGPT signal history for {symbol} on {time_frame.value} for {exchange_name} with {version}. "
f"Please check {self._supported_history_url()} to get the list of supported signals history."
)
self.store_signal_history(
exchange_name, symbol, time_frame, version, signals_by_candle_open_time
)

async def fetch_gpt_history(
self, exchange_name: str, symbols: list, time_frames: list,
version: str, start_timestamp: float, end_timestamp: float
):
authenticator = authentication.Authenticator.instance()
coros = [
self._fetch_and_store_history(
authenticator, exchange_name, symbol, time_frame, version, start_timestamp, end_timestamp
)
for symbol in symbols
for time_frame in time_frames
if not self.has_signal_history(exchange_name, symbol, time_frame, start_timestamp, end_timestamp, version)
]
if coros:
await asyncio.gather(*coros)

def _get_open_candle_timestamp(self, time_frame: commons_enums.TimeFrames, base_timestamp: float):
tf_seconds = commons_enums.TimeFramesMinutes[time_frame] * commons_constants.MINUTE_TO_SECONDS
return base_timestamp - (base_timestamp % tf_seconds)
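
As a quick illustration of the rounding done by _get_open_candle_timestamp, assuming a 1h time frame (a TimeFramesMinutes value of 60) and MINUTE_TO_SECONDS of 60:

tf_seconds = 60 * 60                       # 3600 seconds per 1h candle
base_timestamp = 1698397200 + 945          # 15 min 45 s into an hourly candle
open_time = base_timestamp - (base_timestamp % tf_seconds)
assert open_time == 1698397200             # rounded down to the candle open time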

def clear_signal_history(self):
self.stored_signals.clear()

def _supported_history_url(self):
return f"{community.IdentifiersProvider.COMMUNITY_LANDING_URL}/chat-gpt-trading"

def _ensure_rate_limit(self):
if self.last_consumed_token_date != datetime.date.today():
self.consumed_daily_tokens = 0
@@ -101,7 +238,7 @@ def _update_token_usage(self, consumed_tokens):
self.logger.debug(f"Consumed {consumed_tokens} tokens. {self.consumed_daily_tokens} consumed tokens today.")

def check_required_config(self, config):
if self._env_secret_key is not None:
if self._env_secret_key is not None or self.use_stored_signals_only():
return True
try:
return bool(config[services_constants.CONIG_OPENAI_SECRET_KEY])
@@ -110,6 +247,8 @@

def has_required_configuration(self):
try:
if self.use_stored_signals_only():
return True
return self.check_required_config(
self.config[services_constants.CONFIG_CATEGORY_SERVICES].get(services_constants.CONFIG_GPT, {})
)
@@ -140,6 +279,9 @@ def _get_api_key(self):

async def prepare(self) -> None:
try:
if self.use_stored_signals_only():
self.logger.info("Skipping models fetch as use_stored_signals_only() is True")
return
fetched_models = await openai.Model.alist(api_key=self._get_api_key())
self.models = [d["id"] for d in fetched_models["data"]]
if self.model not in self.models:
@@ -151,11 +293,15 @@ async def prepare(self) -> None:
self.logger.error(f"Unexpected error when checking api key: {err}")

def _is_healthy(self):
return self._get_api_key() and self.models
return self.use_stored_signals_only() or (self._get_api_key() and self.models)

def get_successful_startup_message(self):
return f"GPT configured and ready. {len(self.models)} AI models are available. Using {self.model}.", \
return f"GPT configured and ready. {len(self.models)} AI models are available. " \
f"Using {'stored signals' if self.use_stored_signals_only() else self.models}.", \
self._is_healthy()

def use_stored_signals_only(self):
return not self.config

async def stop(self):
pass
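
To show how the pieces added in this commit fit together, here is a sketch of the expected backtesting flow: signals are pre-fetched into the service's in-memory tree with fetch_gpt_history, then get_chat_completion reads them back with use_stored_signals=True instead of calling the OpenAI API. The import paths, the exchange and symbol values and the version string are illustrative assumptions; only the GPTService methods come from the diff above.

import octobot_commons.enums as commons_enums
import octobot_services.api as services_api
# assumed tentacle import path; adjust to the local tentacles layout
import tentacles.Services.Services_bases.gpt_service as gpt_service

async def preload_and_read_signal(start_ts: float, end_ts: float, candle_open_time: float) -> str:
    # passing {} as services config mirrors the backtesting branch used in GPTEvaluator.ask_gpt
    service = await services_api.get_service(gpt_service.GPTService, True, {})
    # illustrative value; the real one comes from GPTEvaluator.get_version()
    version = "gpt-3.5-turbo-Close-EMA-14-1"
    # fetch the whole signal history once: results are stored in service.stored_signals
    await service.fetch_gpt_history(
        "binance", ["BTC/USDT"], [commons_enums.TimeFrames.ONE_HOUR],
        version, start_ts, end_ts
    )
    # subsequent reads are served from the stored tree, no OpenAI request is made
    return await service.get_chat_completion(
        [],  # messages are ignored when use_stored_signals is True
        exchange="binance", symbol="BTC/USDT",
        time_frame=commons_enums.TimeFrames.ONE_HOUR.value,
        version=version, candle_open_time=candle_open_time,
        use_stored_signals=True,
    )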
