Skip to content

Commit

Permalink
Added a Polars-based accuracy summary to the benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
ds-jakub-cierocki committed Jul 23, 2024
1 parent fbecc51 commit 5d4ff64
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions benchmark/dbally_benchmark/context_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
import polars as pl

import dbally
import asyncio
import typing
import json
import traceback
import os

import tqdm.asyncio
import sqlalchemy
import pydantic
from typing_extensions import TypeAlias
from copy import deepcopy
from sqlalchemy import create_engine
Expand All @@ -31,7 +31,8 @@
Candidate = Base.classes.candidates


class MyData(BaseCallerContext, pydantic.BaseModel):
@dataclass
class MyData(BaseCallerContext):
first_name: str
surname: str
position: str
Expand All @@ -41,7 +42,8 @@ class MyData(BaseCallerContext, pydantic.BaseModel):
country: str


class OpenPosition(BaseCallerContext, pydantic.BaseModel):
@dataclass
class OpenPosition(BaseCallerContext):
position: str
min_years_of_experience: int
graduated_from_university: str
Expand Down Expand Up @@ -130,7 +132,7 @@ def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.Col
return Candidate.name.startswith(first_name)


OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']
OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-3.5-turbo-instruct', 'gpt-4-turbo', 'gpt-4o']


def setup_collection(model_name: OpenAILLMName) -> dbally.Collection:
Expand Down Expand Up @@ -224,10 +226,30 @@ async def main(config: BenchmarkConfig):

output_data[question]["answers"][llm_name].append(answer)

output_data_list = list(output_data.values())
# Raw results frame: one row per benchmark question, with the per-model
# answer lists nested under the "answers" struct column.
# NOTE(review): assumes output_data values are uniform dicts with
# "question", "correct_answer", "context", "answers" keys — confirm upstream.
df_out_raw = pl.DataFrame(list(output_data.values()))

# Per-(context, model) accuracy summary:
#   1. unnest "answers" so each model's answer list becomes its own column
#   2. unpivot the model columns into long form (one row per model)
#   3. explode the answer lists (one row per individual answer)
#   4. aggregate the match rate and match count against correct_answer
# NOTE(review): the selector only picks up columns starting with "gpt",
# so any non-GPT model would be silently excluded from the summary.
df_out = (
    df_out_raw
    .unnest("answers")
    .unpivot(
        on=pl.selectors.starts_with("gpt"),
        index=["question", "correct_answer", "context"],
        variable_name="model",
        value_name="answer"
    )
    .explode("answer")
    .group_by(["context", "model"])
    .agg([
        # frac_hits: fraction of answers exactly equal to the reference;
        # n_hits: absolute count of such exact matches.
        (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"),
        (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"),
    ])
    .sort(["model", "context"])
)

# Print the aggregated accuracy table to stdout for quick inspection.
print(df_out)

with open(config.out_path, 'w') as file:
file.write(json.dumps(test_set, indent=2))
file.write(json.dumps(df_out_raw.to_dicts(), indent=2))


if __name__ == "__main__":
Expand Down

0 comments on commit 5d4ff64

Please sign in to comment.