Skip to content

Commit

Permalink
new text2sql execution scores and results in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
oktie committed Feb 7, 2025
1 parent 6d53232 commit 4390579
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 89 deletions.
87 changes: 78 additions & 9 deletions prepare/metrics/text2sql_execution_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from unitxt.catalog import add_to_catalog
from unitxt.metrics import ExecutionAccuracy
from unitxt.metrics import SQLExecutionAccuracy
from unitxt.test_utils.metrics import test_metric

metric = ExecutionAccuracy()
metric = SQLExecutionAccuracy()

predictions = [
"SELECT nme FROM employees WHERE department = 'Sales'",
"SELECT name FROM employees WHERE department = 'Sales'",
"SELECT name FROM employees WHERE department = 'Engineering'",
"SELECT id, name FROM employees WHERE department = 'Sales'",
"SELECT name FROM employees WHERE department = 'Non-Existent'",
] # Only the first prediction uses an incorrect column name ('nme')
references = [["SELECT name FROM employees WHERE department = 'Sales';"]] * 2
references = [
["SELECT name FROM employees WHERE department = 'Sales';"],
["SELECT name FROM employees WHERE department = 'Sales';"],
["SELECT name FROM employees WHERE department = 'Sales';"],
["SELECT name FROM employees WHERE department = 'Sales';"],
["SELECT name FROM employees WHERE department = 'Non-Existent';"],
]
task_data = [
{
"db": {
Expand All @@ -26,29 +35,84 @@
},
}
}
] * 2
] * 5

instance_targets = [
{
"error_message": "Error executing SQL: no such column: nme",
"execution_accuracy": 0.0,
"gold_df_json": "",
"gold_error": 0.0,
"non_empty_execution_accuracy": 0.0,
"non_empty_gold_df": 0.0,
"predicted_df_json": "",
"predicted_error": 1.0,
"score": 0.0,
"score_name": "execution_accuracy",
"subset_non_empty_execution_result": 0.0,
},
{
"error_message": "",
"execution_accuracy": 1.0,
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
"gold_error": 1.0,
"non_empty_execution_accuracy": 1.0,
"non_empty_gold_df": 1.0,
"predicted_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
"predicted_error": 0.0,
"score": 1.0,
"score_name": "execution_accuracy",
"subset_non_empty_execution_result": 1.0,
},
{
"error_message": "None",
"execution_accuracy": 0.0,
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
"gold_error": 0.0,
"non_empty_execution_accuracy": 0.0,
"non_empty_gold_df": 0.0,
"predicted_df_json": '{"0":{"0":"Bob"}}',
"predicted_error": 0.0,
"score": 0.0,
"score_name": "execution_accuracy",
"subset_non_empty_execution_result": 0.0,
},
{
"error_message": "None",
"execution_accuracy": 0.0,
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
"gold_error": 0.0,
"non_empty_execution_accuracy": 0.0,
"non_empty_gold_df": 0.0,
"predicted_df_json": '{"0":{"0":1,"1":3},"1":{"0":"Alice","1":"Charlie"}}',
"predicted_error": 0.0,
"score": 0.0,
"score_name": "execution_accuracy",
"subset_non_empty_execution_result": 0.0,
},
{
"error_message": "",
"execution_accuracy": 1.0,
"gold_df_json": "{}",
"gold_error": 1.0,
"non_empty_execution_accuracy": 0.0,
"non_empty_gold_df": 0.0,
"predicted_df_json": "{}",
"predicted_error": 0.0,
"score": 1.0,
"score_name": "execution_accuracy",
"subset_non_empty_execution_result": 0.0,
},
]


global_target = {
"execution_accuracy": 0.5,
"execution_accuracy_ci_high": 1.0,
"execution_accuracy": 0.4,
"execution_accuracy_ci_high": 0.87,
"execution_accuracy_ci_low": 0.0,
"num_of_instances": 2,
"score": 0.5,
"score_ci_high": 1.0,
"num_of_instances": 5,
"score": 0.4,
"score_ci_high": 0.87,
"score_ci_low": 0.0,
"score_name": "execution_accuracy",
}
Expand All @@ -60,6 +124,11 @@
instance_targets=instance_targets,
global_target=global_target,
task_data=task_data,
score_keys_to_ignore=[
"predicted_sql_runtime",
"gold_sql_runtime",
"pred_to_gold_runtime_ratio",
],
)

add_to_catalog(metric, "metrics.text2sql.execution_accuracy", overwrite=True)
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"__type__": "execution_accuracy"
"__type__": "sql_execution_accuracy"
}
31 changes: 20 additions & 11 deletions src/unitxt/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def execute_query_local(db_path: str, query: str) -> Any:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(query)
return cursor.fetchall()
return cursor.fetchall(), None
except sqlite3.Error as e:
logger.info(f"Error executing SQL: {e}")
return None
return None, f"Error executing SQL: {e}"
finally:
if conn:
conn.close()
Expand Down Expand Up @@ -178,10 +178,10 @@ def execute_query(self, query: str) -> Any:

try:
cursor.execute(query)
return cursor.fetchall()
return cursor.fetchall(), None
except sqlite3.Error as e:
logger.info(f"Error executing SQL: {e}")
return None
return None, f"Error executing SQL: {e}"
finally:
conn.close()

Expand All @@ -196,7 +196,7 @@ def execute_query_remote(
max_retries: int = 3,
retry_delay: int = 5, # seconds
timeout: int = 30, # seconds
) -> Optional[dict]:
) -> (Optional[dict], str):
"""Executes a query against the remote database, with retries for certain exceptions."""
headers = {
"Content-Type": "application/json",
Expand All @@ -214,7 +214,7 @@ def execute_query_remote(
timeout=timeout,
)
response.raise_for_status()
return response.json()
return response.json(), None

except retryable_exceptions as e:
retries += 1
Expand All @@ -225,7 +225,10 @@ def execute_query_remote(
time.sleep(retry_delay)
else:
logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
return None
return (
None,
f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
)

except requests.exceptions.HTTPError as e:
if e.response.status_code >= 500:
Expand All @@ -239,16 +242,22 @@ def execute_query_remote(
logger.error(
f"Max retries ({max_retries}) exceeded for query: {query}"
)
return None
return (
None,
f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
)
else:
logger.error(f"HTTP Error on attempt {retries}: {e}")
return None
return (
None,
f"HTTP Error on attempt {retries}: {e}",
)

except Exception as e:
logger.error(f"Unexpected error on attempt {retries}: {e}")
return None
return (None, f"Unexpected error on attempt {retries}: {e}")

return None
return None, "Unknown Error in SQL execution"


class RemoteDatabaseConnector(DatabaseConnector):
Expand Down
Loading

0 comments on commit 4390579

Please sign in to comment.