Skip to content

Commit

Permalink
Add langfuse integration (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Aug 21, 2024
1 parent 92eed91 commit 7fb6bd6
Show file tree
Hide file tree
Showing 9 changed files with 587 additions and 415 deletions.
916 changes: 528 additions & 388 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions prediction_prophet/autonolas/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from prediction_prophet.functions.parallelism import par_map
from pydantic.types import SecretStr
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
from langfuse.decorators import langfuse_context
from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe

load_dotenv()

Expand Down Expand Up @@ -1036,6 +1038,7 @@ def join_and_group_sentences(
return final_output


@observe()
def fetch_additional_information(
event_question: str,
max_add_words: int,
Expand Down Expand Up @@ -1088,7 +1091,7 @@ def fetch_additional_information(
) |
StrOutputParser()
)
response = research_chain.invoke({})
response = research_chain.invoke({}, config=get_langfuse_langchain_config())

# Parse the response content
try:
Expand Down Expand Up @@ -1117,6 +1120,7 @@ def fetch_additional_information(
return additional_informations


@observe()
def research(
prompt: str,
max_tokens: int | None = None,
Expand Down Expand Up @@ -1165,19 +1169,20 @@ def research(
enc=enc,
)

# Spacy loads ~500MB into memory, make it free it with force.
# Spacy loads ~500MB into memory. Free it with force.
del nlp
gc.collect()

return additional_information


@observe()
def make_prediction(
prompt: str,
additional_information: str,
temperature: float = 0.7,
engine: str = "gpt-3.5-turbo-0125",
api_key: SecretStr | None = None
api_key: SecretStr | None = None,
) -> Prediction:
if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
Expand All @@ -1189,7 +1194,7 @@ def make_prediction(

llm = ChatOpenAI(model=engine, temperature=temperature, api_key=secretstr_to_v1_secretstr(api_key))
formatted_messages = prediction_prompt.format_messages(user_prompt=prompt, additional_information=additional_information, timestamp=formatted_time_utc)
generation = llm.generate([formatted_messages], logprobs=True, top_logprobs=5)
generation = llm.generate([formatted_messages], logprobs=True, top_logprobs=5, callbacks=[langfuse_context.get_current_langchain_handler()])

completion = generation.generations[0][0].text

Expand Down
18 changes: 10 additions & 8 deletions prediction_prophet/benchmark/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from prediction_prophet.autonolas.research import research as research_autonolas
from prediction_prophet.functions.evaluate_question import is_predictable
from prediction_prophet.functions.rephrase_question import rephrase_question
from prediction_prophet.functions.research import research as prophet_research
from prediction_prophet.functions.research import Research, research as prophet_research
from prediction_prophet.functions.search import search
from prediction_prophet.functions.utils import url_is_older_than
from prediction_prophet.models.WebSearchResult import WebSearchResult
Expand All @@ -25,18 +25,20 @@
)
from pydantic.types import SecretStr
from prediction_prophet.autonolas.research import Prediction as LLMCompletionPredictionDict
from prediction_market_agent_tooling.tools.langfuse_ import observe
from prediction_market_agent_tooling.tools.tavily_storage.tavily_models import TavilyStorage

if t.TYPE_CHECKING:
from loguru import Logger


@observe()
def _make_prediction(
market_question: str,
additional_information: str,
engine: str,
temperature: float,
api_key: SecretStr | None = None
api_key: SecretStr | None = None,
) -> Prediction:
"""
We prompt model to output a simple flat JSON and convert it to a more structured pydantic model here.
Expand All @@ -46,7 +48,7 @@ def _make_prediction(
additional_information=additional_information,
engine=engine,
temperature=temperature,
api_key=api_key
api_key=api_key,
)
return completion_prediction_json_to_pydantic_model(
prediction
Expand Down Expand Up @@ -178,7 +180,7 @@ def is_predictable_restricted(self, market_question: str, time_restriction_up_to
(result, _) = is_predictable(question=market_question)
return result

def research(self, market_question: str) -> str:
def research(self, market_question: str) -> Research:
return prophet_research(
goal=market_question,
model=self.model,
Expand All @@ -187,16 +189,16 @@ def research(self, market_question: str) -> str:
tavily_storage=self.tavily_storage,
logger=self.logger,
)

def predict(self, market_question: str) -> Prediction:
try:
report = self.research(market_question)
research = self.research(market_question)
return _make_prediction(
market_question=market_question,
additional_information=report,
additional_information=research.report,
engine=self.model,
temperature=self.temperature,
)
)
except ValueError as e:
print(f"Error in PredictionProphet's predict: {e}")
return Prediction()
Expand Down
4 changes: 3 additions & 1 deletion prediction_prophet/functions/generate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe


subquery_generation_template = """
Expand All @@ -15,6 +16,7 @@
Limit your searches to {search_limit}.
"""
@observe()
def generate_subqueries(query: str, limit: int, model: str, api_key: SecretStr | None = None) -> list[str]:
if limit == 0:
return [query]
Expand All @@ -33,6 +35,6 @@ def generate_subqueries(query: str, limit: int, model: str, api_key: SecretStr |
subqueries = subquery_generation_chain.invoke({
"query": query,
"search_limit": limit
})
}, config=get_langfuse_langchain_config())

return [query] + [subquery.strip('\"') for subquery in subqueries]
8 changes: 5 additions & 3 deletions prediction_prophet/functions/prepare_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from pydantic.types import SecretStr
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr

from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe

@persistent_inmemory_cache
@observe()
def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | None = None, trim_content_to_tokens: t.Optional[int] = None) -> str:
if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
Expand All @@ -36,11 +37,12 @@ def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | No
response: str = research_evaluation_chain.invoke({
"goal": goal,
"content": content
})
}, config=get_langfuse_langchain_config())

return response


@observe()
def prepare_report(goal: str, scraped: list[str], model: str, api_key: SecretStr | None = None) -> str:
if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
Expand Down Expand Up @@ -77,6 +79,6 @@ def prepare_report(goal: str, scraped: list[str], model: str, api_key: SecretStr
response: str = research_evaluation_chain.invoke({
"search_results": scraped,
"goal": goal
})
}, config=get_langfuse_langchain_config())

return response
4 changes: 3 additions & 1 deletion prediction_prophet/functions/rerank_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe

rerank_queries_template = """
I will present you with a list of queries to search the web for, for answers to the question: {goal}.
Expand All @@ -16,6 +17,7 @@
Queries: {queries}
"""
@observe()
def rerank_subqueries(queries: list[str], goal: str, model: str, api_key: SecretStr | None = None) -> list[str]:
if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
Expand All @@ -31,6 +33,6 @@ def rerank_subqueries(queries: list[str], goal: str, model: str, api_key: Secret
responses: str = rerank_results_chain.invoke({
"goal": goal,
"queries": "\n---query---\n".join(queries)
})
}, config=get_langfuse_langchain_config())

return responses.split(",")
32 changes: 25 additions & 7 deletions prediction_prophet/functions/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,23 @@
from prediction_prophet.functions.scrape_results import scrape_results
from prediction_prophet.functions.search import search
from pydantic.types import SecretStr
from pydantic import BaseModel
from prediction_market_agent_tooling.tools.langfuse_ import observe
from prediction_market_agent_tooling.tools.tavily_storage.tavily_models import TavilyStorage

if t.TYPE_CHECKING:
from loguru import Logger


class Research(BaseModel):
    """Structured result of `research`: the report plus the intermediate artifacts behind it."""

    # Final report text that `research` logs just before returning.
    report: str
    # Full output of `generate_subqueries` for the goal, before any reranking.
    all_queries: list[str]
    # Queries actually used for searching: the `rerank_subqueries` output cut to
    # `subqueries_limit` when `initial_subqueries_limit > subqueries_limit`,
    # otherwise the same list as `all_queries`.
    reranked_queries: list[str]
    # URLs collected from the search results; built from a set, so entries are
    # unique but their order is not guaranteed.
    websites_to_scrape: list[str]
    # Filled from the `scraped` variable in `research`; how that list is built
    # is not visible in this view -- presumably the scrape results. TODO confirm.
    websites_scraped: list[WebScrapeResult]


@observe()
def research(
goal: str,
use_summaries: bool,
Expand All @@ -31,7 +43,7 @@ def research(
tavily_api_key: SecretStr | None = None,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
tavily_storage: TavilyStorage | None = None,
) -> str:
) -> Research:
# Validate args
if min_scraped_sites > max_results_per_search * subqueries_limit:
raise ValueError(
Expand All @@ -41,13 +53,13 @@ def research(
)

logger.info("Started subqueries generation")
queries = generate_subqueries(query=goal, limit=initial_subqueries_limit, model=model, api_key=openai_api_key)
all_queries = generate_subqueries(query=goal, limit=initial_subqueries_limit, model=model, api_key=openai_api_key)

stringified_queries = '\n- ' + '\n- '.join(queries)
stringified_queries = '\n- ' + '\n- '.join(all_queries)
logger.info(f"Generated subqueries: {stringified_queries}")

logger.info("Started subqueries reranking")
queries = rerank_subqueries(queries=queries, goal=goal, model=model, api_key=openai_api_key)[:subqueries_limit] if initial_subqueries_limit > subqueries_limit else queries
queries = rerank_subqueries(queries=all_queries, goal=goal, model=model, api_key=openai_api_key)[:subqueries_limit] if initial_subqueries_limit > subqueries_limit else all_queries

stringified_queries = '\n- ' + '\n- '.join(queries)
logger.info(f"Reranked subqueries. Will use top {subqueries_limit}: {stringified_queries}")
Expand All @@ -65,7 +77,7 @@ def research(
raise ValueError(f"No search results found for the goal {goal}.")

scrape_args = [result for (_, result) in search_results_with_queries]
websites_to_scrape = set([result.url for result in scrape_args])
websites_to_scrape = set(result.url for result in scrape_args)

stringified_websites = '\n- ' + '\n- '.join(websites_to_scrape)
logger.info(f"Found the following relevant results: {stringified_websites}")
Expand Down Expand Up @@ -126,7 +138,7 @@ def research(
content,
"gpt-3.5-turbo-0125",
api_key=openai_api_key,
trim_content_to_tokens=14_000
trim_content_to_tokens=14_000,
)
for content in url_to_content_deemed_most_useful.values()
]
Expand All @@ -137,4 +149,10 @@ def research(
logger.info(f"Report prepared")
logger.info(report)

return report
return Research(
all_queries=all_queries,
reranked_queries=queries,
report=report,
websites_to_scrape=list(websites_to_scrape),
websites_scraped=scraped,
)
5 changes: 3 additions & 2 deletions prediction_prophet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def research(
start = time.time()

with get_openai_callback() as cb:
report = prophet_research(goal=prompt, use_summaries=False, model="gpt-4-0125-preview")
research = prophet_research(goal=prompt, use_summaries=False, model="gpt-4-0125-preview")

report = research.report
end = time.time()

if file:
Expand All @@ -66,7 +67,7 @@ def predict(prompt: str, path: str | None = None) -> None:
else:
logger = logging.getLogger("research")
logger.setLevel(logging.INFO)
report = prophet_research(goal=prompt, model="gpt-4-0125-preview", use_summaries=False, logger=logger)
report = prophet_research(goal=prompt, model="gpt-4-0125-preview", use_summaries=False, logger=logger).report

prediction = make_debated_prediction(prompt=prompt, additional_information=report)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ scikit-learn = "^1.4.0"
typer = ">=0.9.0,<1.0.0"
types-requests = "^2.31.0.20240125"
types-python-dateutil = "^2.9.0"
prediction-market-agent-tooling = { version = ">=0.48.0,<1", extras = ["langchain", "google"] }
prediction-market-agent-tooling = { version = ">=0.48.3,<1", extras = ["langchain", "google"] }
langchain-community = "^0.2.6"
memory-profiler = "^0.61.0"
matplotlib = "^3.8.3"
Expand Down

0 comments on commit 7fb6bd6

Please sign in to comment.