From db32c52a6981c4630a56249380a7eb1a219fa6af Mon Sep 17 00:00:00 2001 From: Nestor Amesty Date: Mon, 11 Mar 2024 14:00:19 +0100 Subject: [PATCH] Persist question and key between runs. Predictable returns reasoning --- .../app.py | 22 +++++++++--- evo_researcher/autonolas/research.py | 2 +- evo_researcher/benchmark/agents.py | 12 ++++--- evo_researcher/functions/evaluate_question.py | 34 ++++++++++++------- evo_researcher/functions/grade_info.py | 2 +- evo_researcher/functions/rerank_results.py | 2 +- evo_researcher/functions/research.py | 2 +- evo_researcher/main.py | 6 ++-- scripts/benchmark.py | 2 +- tests/test_evaluate_question.py | 3 +- 10 files changed, 57 insertions(+), 30 deletions(-) rename scripts/public_agent_app.py => evo_researcher/app.py (77%) diff --git a/scripts/public_agent_app.py b/evo_researcher/app.py similarity index 77% rename from scripts/public_agent_app.py rename to evo_researcher/app.py index 68829231..1c05f910 100644 --- a/scripts/public_agent_app.py +++ b/evo_researcher/app.py @@ -32,18 +32,30 @@ def log(self, msg: str) -> None: st.title("Evo Predict") with st.form("question_form", clear_on_submit=True): - question = st.text_input('Question', placeholder="Will Twitter implement a new misinformation policy before the end of 2024") - openai_api_key = st.text_input('OpenAI API Key', placeholder="sk-...", type="password") + question = st.text_input( + 'Question', + placeholder="Will Twitter implement a new misinformation policy before the end of 2024", + value=st.session_state.get('question', '') + ) + openai_api_key = st.text_input( + 'OpenAI API Key', + placeholder="sk-...", + type="password", + value=st.session_state.get('openai_api_key', '') + ) submit_button = st.form_submit_button('Predict') if submit_button and question and openai_api_key: + st.session_state['openai_api_key'] = openai_api_key + st.session_state['question'] = question + with st.container(): with st.spinner("Evaluating question..."): - is_predictable = evaluate_if_predictable(question=question) + (is_predictable, reasoning) = evaluate_if_predictable(question=question, api_key=openai_api_key) st.container(border=True).markdown(f"""### Question evaluation\n\nQuestion: **{question}**\n\nIs predictable: `{is_predictable}`""") if not is_predictable: - st.container().error("The agent thinks this question is not predictable.") + st.container().error(f"The agent thinks this question is not predictable: \n\n{reasoning}") st.stop() with st.spinner("Researching..."): @@ -57,7 +69,7 @@ def log(self, msg: str) -> None: with st.spinner("Predicting..."): with st.container(border=True): - prediction = _make_prediction(market_question=question, additional_information=report, engine="gpt-4-1106-preview", temperature=0.0) + prediction = _make_prediction(market_question=question, additional_information=report, engine="gpt-4-0125-preview", temperature=0.0, api_key=openai_api_key) with st.container().expander("Show agent's prediction", expanded=False): if prediction.outcome_prediction == None: st.container().error("The agent failed to generate a prediction") diff --git a/evo_researcher/autonolas/research.py b/evo_researcher/autonolas/research.py index 1623cc24..c112a361 100644 --- a/evo_researcher/autonolas/research.py +++ b/evo_researcher/autonolas/research.py @@ -1166,7 +1166,7 @@ def make_prediction( prompt: str, additional_information: str, temperature: float = 0.7, - engine: str = "gpt-3.5-turbo-1106", + engine: str = "gpt-3.5-turbo-0125", api_key: str | None = None ) -> Prediction: if api_key == None: diff --git a/evo_researcher/benchmark/agents.py b/evo_researcher/benchmark/agents.py index 827d3aa1..99ce04fb 100644 --- a/evo_researcher/benchmark/agents.py +++ b/evo_researcher/benchmark/agents.py @@ -107,10 +107,12 @@ def __init__( self.embedding_model = embedding_model def is_predictable(self, market_question: str) -> bool: - return is_predictable(question=market_question) + (result, _) = is_predictable(question=market_question) + return result def is_predictable_restricted(self, market_question: str, time_restriction_up_to: datetime) -> bool: - return is_predictable(question=market_question) + (result, _) = is_predictable(question=market_question) + return result def research(self, market_question: str) -> str: return research_autonolas( @@ -164,10 +166,12 @@ def __init__( self.use_tavily_raw_content = use_tavily_raw_content def is_predictable(self, market_question: str) -> bool: - return is_predictable(question=market_question) + (result, _) = is_predictable(question=market_question) + return result def is_predictable_restricted(self, market_question: str, time_restriction_up_to: datetime) -> bool: - return is_predictable(question=market_question) + (result, _) = is_predictable(question=market_question) + return result def predict(self, market_question: str) -> Prediction: try: diff --git a/evo_researcher/functions/evaluate_question.py b/evo_researcher/functions/evaluate_question.py index 94d66354..02e451c9 100644 --- a/evo_researcher/functions/evaluate_question.py +++ b/evo_researcher/functions/evaluate_question.py @@ -1,3 +1,6 @@ +import json +import os +from evo_researcher.autonolas.research import clean_completion_json from langchain_openai import ChatOpenAI from langchain.prompts import ChatPromptTemplate from evo_researcher.functions.cache import persistent_inmemory_cache @@ -19,7 +22,16 @@ Then, write down what is the future event of the question, what it reffers to and when that event will happen if the question contains it. -Then, give your final decision, write either "yes" or "no" about whether the question is answerable. +Then, give your final decision about whether the question is answerable. + +Return a JSON object with the following structure: + +{{ + "is_predictable": bool, + "reasoning": string +}} + +Output only the JSON object in your response. Do not include any other contents in your response. """ @@ -27,23 +39,21 @@ @persistent_inmemory_cache def is_predictable( question: str, - engine: str = "gpt-4-1106-preview", + engine: str = "gpt-4-0125-preview", prompt_template: str = QUESTION_EVALUATE_PROMPT, -) -> bool: + api_key: str | None = None +) -> tuple[bool, str]: """ Evaluate if the question is actually answerable. """ - llm = ChatOpenAI(model=engine, temperature=0.0) + + if api_key == None: + api_key = os.environ.get("OPENAI_API_KEY", "") + llm = ChatOpenAI(model=engine, temperature=0.0, api_key=api_key) prompt = ChatPromptTemplate.from_template(template=prompt_template) messages = prompt.format_messages(question=question) completion = llm(messages, max_tokens=256).content + response = json.loads(clean_completion_json(completion)) - if "yes" in completion.lower(): - is_predictable = True - elif "no" in completion.lower(): - is_predictable = False - else: - raise ValueError(f"Error in evaluate_question for `{question}`: {completion}") - - return is_predictable + return (response["is_predictable"], response["reasoning"]) diff --git a/evo_researcher/functions/grade_info.py b/evo_researcher/functions/grade_info.py index 573e75d0..4b17b33c 100644 --- a/evo_researcher/functions/grade_info.py +++ b/evo_researcher/functions/grade_info.py @@ -70,7 +70,7 @@ def grade_info(question: str, information: str) -> str: planning_prompt = ChatPromptTemplate.from_template(template=grading_planning_prompt_template) formatting_prompt = ChatPromptTemplate.from_template(template=grading_format_prompt_template) - llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0) + llm = ChatOpenAI(model="gpt-4-0125-preview", temperature=0) planning_chain = ( planning_prompt | diff --git a/evo_researcher/functions/rerank_results.py b/evo_researcher/functions/rerank_results.py index 8bfde4df..8ffa08c3 100644 --- a/evo_researcher/functions/rerank_results.py +++ b/evo_researcher/functions/rerank_results.py @@ -19,7 +19,7 @@ def rerank_results(results: list[str], goal: str) -> list[str]: rerank_results_chain = ( rerank_results_prompt | - ChatOpenAI(model="gpt-4-1106-preview") | + ChatOpenAI(model="gpt-4-0125-preview") | CommaSeparatedListOutputParser() ) diff --git a/evo_researcher/functions/research.py b/evo_researcher/functions/research.py index 2c4d3503..1ae6f560 100644 --- a/evo_researcher/functions/research.py +++ b/evo_researcher/functions/research.py @@ -12,7 +12,7 @@ def research( goal: str, use_summaries: bool, - model: str = "gpt-4-1106-preview", + model: str = "gpt-4-0125-preview", initial_subqueries_limit: int = 20, subqueries_limit: int = 4, scrape_content_split_chunk_size: int = 800, diff --git a/evo_researcher/main.py b/evo_researcher/main.py index 7e798e98..e44449cb 100644 --- a/evo_researcher/main.py +++ b/evo_researcher/main.py @@ -34,7 +34,7 @@ def research( start = time.time() with get_openai_callback() as cb: - report = evo_research(goal=prompt, use_summaries=False, model="gpt-4-1106-preview") + report = evo_research(goal=prompt, use_summaries=False, model="gpt-4-0125-preview") end = time.time() @@ -58,9 +58,9 @@ def predict(prompt: str, path: str | None = None) -> None: if path: report = read_text_file(path) else: - report = evo_research(goal=prompt, model="gpt-4-1106-preview", use_summaries=False) + report = evo_research(goal=prompt, model="gpt-4-0125-preview", use_summaries=False) - prediction = _make_prediction(market_question=prompt, additional_information=report, engine="gpt-4-1106-preview", temperature=0.0) + prediction = _make_prediction(market_question=prompt, additional_information=report, engine="gpt-4-0125-preview", temperature=0.0) end = time.time() diff --git a/scripts/benchmark.py b/scripts/benchmark.py index f5d8e975..16d7467c 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -83,7 +83,7 @@ def main( agent_name="evo_gpt-3.5-turbo-0125_tavilyrawcontent", use_tavily_raw_content=True, ), - # EvoAgent(model="gpt-4-1106-preview", max_workers=max_workers, agent_name="evo_gpt-4-1106-preview"), # Too expensive to be enabled by default. + # EvoAgent(model="gpt-4-0125-preview", max_workers=max_workers, agent_name="evo_gpt-4-0125-preview"), # Too expensive to be enabled by default. ], cache_path=cache_path, only_cached=only_cached, diff --git a/tests/test_evaluate_question.py b/tests/test_evaluate_question.py index 3458db39..8b2229bb 100644 --- a/tests/test_evaluate_question.py +++ b/tests/test_evaluate_question.py @@ -11,4 +11,5 @@ ("Did COVID-19 come from a laboratory?", False), ]) def test_evaluate_question(question: str, answerable: bool) -> None: - assert is_predictable(question=question) == answerable, f"Question is not evaluated correctly, see the completion: {is_predictable}" + (result, _) = is_predictable(question=question) + assert result == answerable, f"Question is not evaluated correctly, see the completion: {is_predictable}"