Skip to content

Commit

Permalink
Detailed agent errors
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii committed Nov 18, 2024
1 parent 6069d05 commit 4c72086
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
17 changes: 12 additions & 5 deletions prediction_prophet/benchmark/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from prediction_prophet.autonolas.research import make_prediction, get_urls_from_queries
from prediction_prophet.autonolas.research import research as research_autonolas
from prediction_prophet.functions.rephrase_question import rephrase_question
from prediction_prophet.functions.research import Research, research as prophet_research
from prediction_prophet.functions.research import NoResulsFoundError, NotEnoughScrapedSitesError, Research, research as prophet_research
from prediction_prophet.functions.search import search
from prediction_prophet.functions.utils import url_is_older_than
from prediction_prophet.models.WebSearchResult import WebSearchResult
Expand Down Expand Up @@ -71,10 +71,12 @@ def __init__(
temperature: float = 0.0,
agent_name: str = "question-only",
max_workers: t.Optional[int] = None,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
):
super().__init__(agent_name=agent_name, max_workers=max_workers)
self.model: str = model
self.temperature = temperature
self.logger = logger

def predict(
self, market_question: str
Expand All @@ -87,7 +89,7 @@ def predict(
temperature=self.temperature,
)
except ValueError as e:
print(f"Error in QuestionOnlyAgent's predict: {e}")
self.logger.error(f"Error in QuestionOnlyAgent's predict: {e}")
return Prediction()

def predict_restricted(
Expand All @@ -104,11 +106,13 @@ def __init__(
agent_name: str = "olas",
max_workers: t.Optional[int] = None,
embedding_model: EmbeddingModel = EmbeddingModel.spacy,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
):
super().__init__(agent_name=agent_name, max_workers=max_workers)
self.model: str = model
self.temperature = temperature
self.embedding_model = embedding_model
self.logger = logger

def is_predictable(self, market_question: str) -> bool:
result = is_predictable_binary(question=market_question)
Expand All @@ -135,7 +139,7 @@ def predict(self, market_question: str) -> Prediction:
temperature=self.temperature,
)
except ValueError as e:
print(f"Error in OlasAgent's predict: {e}")
self.logger.error(f"Error in OlasAgent's predict: {e}")
return Prediction()

def predict_restricted(
Expand Down Expand Up @@ -215,8 +219,11 @@ def predict(self, market_question: str) -> Prediction:
temperature=self.prediction_temperature,
include_reasoning=self.include_reasoning,
)
except ValueError as e:
print(f"Error in PredictionProphet's predict: {e}")
except (NoResulsFoundError, NotEnoughScrapedSitesError) as e:
self.logger.warning(f"Problem in PredictionProphet's predict: {e}")
return Prediction()
except Exception as e:
self.logger.error(f"Error in PredictionProphet's predict: {e}")
return Prediction()

def predict_restricted(
Expand Down
12 changes: 10 additions & 2 deletions prediction_prophet/functions/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ class Research(BaseModel):
websites_scraped: list[WebScrapeResult]


class NoResulsFoundError(ValueError):
    """Raised when the web search returns no results for the research goal.

    Subclasses ``ValueError`` so existing ``except ValueError`` handlers
    keep working.
    """


# NOTE(review): the class name above is misspelled ("Resuls" -> "Results").
# Renaming it would break callers that import the misspelled name, so a
# correctly-spelled alias is provided instead; both refer to the same class.
NoResultsFoundError = NoResulsFoundError


class NotEnoughScrapedSitesError(ValueError):
    """Raised when fewer sites were successfully scraped than the required
    minimum (``min_scraped_sites``).

    Subclasses ``ValueError`` so existing ``except ValueError`` handlers
    keep working.
    """


@observe()
def research(
goal: str,
Expand Down Expand Up @@ -72,7 +80,7 @@ def research(
)

if not search_results_with_queries:
raise ValueError(f"No search results found for the goal {goal}.")
raise NoResulsFoundError(f"No search results found for the goal {goal}.")

scrape_args = [result for (_, result) in search_results_with_queries]
websites_to_scrape = set(result.url for result in scrape_args)
Expand All @@ -92,7 +100,7 @@ def research(
unique_scraped_websites = set([result.url for result in scraped])
if len(scraped) < min_scraped_sites:
# Get urls that were not scraped
raise ValueError(
raise NotEnoughScrapedSitesError(
f"Only successfully scraped content from "
f"{len(unique_scraped_websites)} websites, out of a possible "
f"{len(websites_to_scrape)} websites, which is less than the "
Expand Down

0 comments on commit 4c72086

Please sign in to comment.