-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve runtimes by 'pushing up' common Case Statements into precomputed values #2580
Comments
Some initial experimentation with an LLM prompt (click to expand):
# MAIN INSTRUCTION TO LLM
# CONSIDER THIS EXAMPLE AND GIVE ME AN EXAMPLE OF HOW TO USE SQLGLOT
# TO AUTOMATICALLY CONVERT A SQL STATEMENT IN INEFFICIENT FORMAT 1
# TO THE EFFICIENT FORMAT 2
# Set up a small demo table and a self-joined view in DuckDB, then run the
# "inefficient" (Format 1) and "efficient" (Format 2) similarity queries.
import duckdb

con = duckdb.connect()

# FIX: the original VALUES list had a trailing comma after (4, 'Michael'),
# which is a syntax error in standard SQL — removed.
sql = """
CREATE TABLE names AS
SELECT * FROM (
VALUES
(1, 'John'),
(2, 'Jonathan'),
(3, 'Johnny'),
(4, 'Michael')
) AS t(id, first_name)
"""
con.execute(sql)

# All distinct pairs of names (a.id < b.id avoids self-pairs and duplicates).
sql = """
CREATE VIEW joined_names AS
SELECT
a.id as id_l,
a.first_name as name_l,
b.id as id_r,
b.first_name as name_r
FROM names a
CROSS JOIN names b
WHERE a.id < b.id
"""
con.execute(sql)

# FORMAT 1: Inefficient version — levenshtein() is recomputed once per WHEN
# branch. FIX: the original ordered by `distance`, a column never defined in
# this query (it only exists in Format 2) — order by the expression instead.
sql = """
WITH similarity_calc AS (
SELECT
*,
CASE
WHEN levenshtein(name_l, name_r) < 2 THEN 0
WHEN levenshtein(name_l, name_r) < 4 THEN 1
WHEN levenshtein(name_l, name_r) < 6 THEN 2
ELSE 3
END as similarity_bin
FROM joined_names
)
SELECT * FROM similarity_calc
ORDER BY levenshtein(name_l, name_r)
"""
con.sql(sql)

# FORMAT 2: Efficient version — compute the distance once in a base CTE and
# reuse it in every WHEN branch and in the ORDER BY.
sql = """
WITH base_calc AS (
SELECT
*,
levenshtein(name_l, name_r) as distance
FROM joined_names
),
similarity_calc AS (
SELECT
*,
CASE
WHEN distance < 2 THEN 0
WHEN distance < 4 THEN 1
WHEN distance < 6 THEN 2
ELSE 3
END as similarity_bin
FROM base_calc
)
SELECT * FROM similarity_calc
ORDER BY distance
"""
con.sql(sql)
# Suggested solution: rewrite Format 1 into Format 2 with a sqlglot AST transform.
import sqlglot
from sqlglot import exp  # FIX: `exp` was used throughout below but never imported
from sqlglot.optimizer import optimize

# Inefficient SQL query (Format 1)
sql_format_1 = """
WITH similarity_calc AS (
SELECT
*,
CASE
WHEN levenshtein(name_l, name_r) < 2 THEN 0
WHEN levenshtein(name_l, name_r) < 4 THEN 1
WHEN levenshtein(name_l, name_r) < 6 THEN 2
ELSE 3
END as similarity_bin
FROM joined_names
)
SELECT * FROM similarity_calc
ORDER BY distance
"""

# 1. Parse the SQL query
expression = sqlglot.parse_one(sql_format_1)


# 2. Define a transformation that simplifies the query into Format 2
def my_custom_transformation(expression):
    """Rewrite the hard-coded Format 1 query into Format 2.

    This specific transformation will only work for this specific example:
    it requires the original CTE to be named similarity_calc and expects
    the distance column to be introduced there (as opposed to a base_calc
    CTE). Any query that doesn't match is returned unchanged.
    """
    # Find the similarity_calc CTE; bail out (no-op) if the shape doesn't match.
    similarity_calc_cte = expression.find(exp.CTE)
    if not similarity_calc_cte or similarity_calc_cte.alias != "similarity_calc":
        return expression

    # Capture the original select expression from the similarity_calc CTE.
    select_expression = similarity_calc_cte.this
    if not isinstance(select_expression, exp.Select):
        return expression

    # The CASE projection is expected to be the last (aliased) select item.
    case_expression = select_expression.selects[-1]
    if not isinstance(case_expression, exp.Alias):
        return expression

    # Drop the CASE projection and replace it with a precomputed `distance`
    # column; this select becomes the body of the new base_calc CTE.
    select_expression.selects.pop()
    select_expression.select(
        exp.alias_(exp.func("levenshtein", exp.column("name_l"), exp.column("name_r")), "distance"),
        copy=False,
    )

    # Rebuild similarity_calc so its CASE reads the precomputed `distance`.
    similarity_calc_cte.set(
        "this",
        exp.select(
            "*",
            exp.case()
            .when(exp.column("distance").lt(2), 0)
            .when(exp.column("distance").lt(4), 1)
            .when(exp.column("distance").lt(6), 2)
            .else_(3)
            .as_("similarity_bin"),
            copy=False,
        ).from_("base_calc"),
    )

    # New base_calc CTE that computes the distance exactly once.
    base_calc_cte = exp.CTE(
        alias=exp.to_identifier("base_calc"),
        this=exp.select(
            "*",
            exp.func("levenshtein", exp.column("name_l"), exp.column("name_r")).as_("distance"),
            copy=False,
        ).from_("joined_names"),
        copy=False,
    )

    # Prepend base_calc to the WITH clause (creating one if absent).
    with_expression = expression.args.get("with")
    if with_expression:
        with_expression.set("expressions", [base_calc_cte, *with_expression.expressions])
    else:
        expression.set("with", exp.With(expressions=[base_calc_cte]))
    return expression


# 3. Apply the custom transformation
transformed_expression = expression.transform(my_custom_transformation)
# 4. Optimize the transformation
optimized_expression = optimize(transformed_expression)
# 5. Output the transformed SQL (Format 2)
sql_format_2 = optimized_expression.sql(pretty=True)
print(sql_format_2)
(Python — use code with caution.)
Explanation — a second experimental script:
import sqlglot
from sqlglot import exp

sql = """
SELECT
CASE
WHEN levenshtein(name_l, name_r) < 2 THEN 0
WHEN levenshtein(name_l, name_r) < 4 THEN 1
WHEN levenshtein(name_l, name_r) < 6 THEN 2
WHEN jaro(name_l, name_r) < 0.9 THEN 2
WHEN jaro(name_l, name_r) < 0.8 THEN 2
ELSE 3
END as similarity_bin
FROM joined_names
"""

expression = sqlglot.parse_one(sql)

# Tally how many times each rendered function expression appears across the
# WHEN branches of every CASE in the query.
usage_tally = {}
for case_node in expression.find_all(exp.Case):
    for when_branch in case_node.find_all(exp.If):
        for fn_node in when_branch.find_all(exp.Func):
            rendered = fn_node.sql()
            usage_tally[rendered] = usage_tally.get(rendered, 0) + 1

# Only functions that occur more than once are candidates for precomputation.
repeated_functions = {rendered for rendered, hits in usage_tally.items() if hits > 1}


def transform(node):
    """Swap each repeated function call for a sanitized string placeholder."""
    if not isinstance(node, exp.Func):
        return node
    rendered = node.sql()
    if rendered not in repeated_functions:
        return node
    # Lowercase the rendered SQL and keep only alphanumerics (rest become '_').
    chars = [ch if ch.isalnum() else "_" for ch in rendered.lower()]
    return exp.Literal.string("".join(chars))


transformed = expression.transform(transform)
print(transformed.sql())
But I came to the conclusion this is not the right approach — instead, we should detect common statements at the point we generate the CASE statement, because at that stage the SQL code is much simpler.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Currently, in Splink, comparison functions (e.g., cosine_sim) are evaluated multiple times within CASE statements during the predict() process. For example, this results in repeated computation of cosine_sim(l, r) for each threshold, which impacts performance.
Proposed Enhancement:
Where functions are being repeatedly computed, automatically refactor the logic to precompute the comparison value once and reuse it in the CASE statement. For example, this would make it easier to use more granular thresholds without additional overhead — and allow users to more easily approximate match weights as a linear function rather than bucketed (by simply using a large number of thresholds).
The text was updated successfully, but these errors were encountered: