Improve runtimes by 'pushing up' common CASE statements into precomputed values #2580

Open
RobinL opened this issue Jan 8, 2025 · 1 comment


@RobinL
Member

RobinL commented Jan 8, 2025

Currently, in Splink, comparison functions (e.g., cosine_sim) are evaluated multiple times within CASE statements during the predict() process. For example:

CASE
    WHEN cosine_sim(l, r) > 0.9 THEN 1
    WHEN cosine_sim(l, r) > 0.8 THEN 2
    WHEN cosine_sim(l, r) > 0.7 THEN 3
    ...
END

This can lead to cosine_sim(l, r) being recomputed for every threshold on each record pair, which slows down predict().

Proposed Enhancement:

Where functions are computed repeatedly, automatically refactor the logic to precompute the comparison value once and reuse it in the CASE statement. For example:

WITH precomputed_cosine AS (
    SELECT cosine_sim(l, r) AS precompute_cosine_value, ...
)
SELECT
    CASE
        WHEN precompute_cosine_value > 0.9 THEN 1
        WHEN precompute_cosine_value > 0.8 THEN 2
        WHEN precompute_cosine_value > 0.7 THEN 3
        ...
    END
FROM precomputed_cosine
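To illustrate the cost described above, here is a Python analogue. It is not Splink code and not how any particular SQL engine evaluates a CASE; some engines can eliminate common subexpressions on their own, so this models naive evaluation only. `toy_similarity` is a hypothetical stand-in for `cosine_sim`, the ELSE level of 4 is an assumption (the SQL above elides it), and a counter records how often the similarity function runs under each strategy.

```python
from typing import Callable

THRESHOLDS = [0.9, 0.8, 0.7]  # same thresholds as the CASE examples above
ELSE_LEVEL = len(THRESHOLDS) + 1  # assumed ELSE level; the original SQL elides it


class CountingSimilarity:
    """Wraps a similarity function and counts how many times it is called."""

    def __init__(self, fn: Callable[[str, str], float]) -> None:
        self.fn = fn
        self.calls = 0

    def __call__(self, l: str, r: str) -> float:
        self.calls += 1
        return self.fn(l, r)


def toy_similarity(l: str, r: str) -> float:
    """Hypothetical stand-in for cosine_sim: share of positions with equal characters."""
    longest = max(len(l), len(r))
    if longest == 0:
        return 1.0
    return sum(a == b for a, b in zip(l, r)) / longest


def level_repeated(sim: Callable[[str, str], float], l: str, r: str) -> int:
    # Mirrors the original CASE: every WHEN calls the function again
    for level, threshold in enumerate(THRESHOLDS, start=1):
        if sim(l, r) > threshold:
            return level
    return ELSE_LEVEL


def level_precomputed(sim: Callable[[str, str], float], l: str, r: str) -> int:
    # Mirrors the proposal: compute once, then compare the stored value
    value = sim(l, r)
    for level, threshold in enumerate(THRESHOLDS, start=1):
        if value > threshold:
            return level
    return ELSE_LEVEL


repeated = CountingSimilarity(toy_similarity)
precomputed = CountingSimilarity(toy_similarity)
print(level_repeated(repeated, "john", "jane"), repeated.calls)  # 4 3
print(level_precomputed(precomputed, "john", "jane"), precomputed.calls)  # 4 1
```

For a pair that falls through every WHEN, the repeated version calls the function once per threshold, while the precomputed version calls it once regardless of how many thresholds there are.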

This would make it easier to use more granular thresholds without additional overhead, and would let users approximate match weights that vary linearly with the similarity score, rather than in buckets, simply by using a large number of thresholds.
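As a sketch of how many thresholds could be generated once the value is precomputed, the hypothetical helper below (not part of Splink's API) builds a CASE expression over a single precomputed column such as `precompute_cosine_value`. Thresholds are tested from highest to lowest, matching the ordering in the example above; the ELSE level is an assumption.

```python
def threshold_case_sql(column: str, thresholds: list[float]) -> str:
    """Build a CASE expression that buckets one precomputed column.

    Hypothetical helper: thresholds are sorted in descending order so the
    first WHEN that matches is the most similar level, as in the examples above.
    """
    ordered = sorted(thresholds, reverse=True)
    lines = ["CASE"]
    for level, threshold in enumerate(ordered, start=1):
        lines.append(f"    WHEN {column} > {threshold} THEN {level}")
    # Assumed catch-all level for values at or below the lowest threshold
    lines.append(f"    ELSE {len(ordered) + 1}")
    lines.append("END")
    return "\n".join(lines)


# Three thresholds reproduce the example above
print(threshold_case_sql("precompute_cosine_value", [0.7, 0.8, 0.9]))

# Fifty thresholds in steps of 0.01, all reading the same precomputed column
fine_grained = threshold_case_sql("precompute_cosine_value", [t / 100 for t in range(50, 100)])
print(fine_grained.count("WHEN"))  # 50
```

Sorting inside the helper means callers do not need to supply thresholds in order, and adding thresholds only adds cheap comparisons against the stored value rather than extra similarity calls.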

@RobinL
Member Author

RobinL commented Jan 8, 2025

Some initial experimentation with an LLM prompt:

# MAIN INSTRUCTION TO LLM
# CONSIDER THIS EXAMPLE AND GIVE ME AN EXAMPLE OF HOW TO USE SQLGLOT
# TO AUTOMATICALLY CONVERT A SQL STATEMENT IN INEFFICIENT FORMAT 1
# TO THE EFFICIENT FORMAT 2

import duckdb

con = duckdb.connect()

sql = """
CREATE TABLE names AS
SELECT * FROM (
    VALUES
    (1, 'John'),
    (2, 'Jonathan'),
    (3, 'Johnny'),
    (4, 'Michael')
) AS t(id, first_name)
"""
con.execute(sql)

sql = """
CREATE VIEW joined_names AS
SELECT
    a.id as id_l,
    a.first_name as name_l,
    b.id as id_r,
    b.first_name as name_r
FROM names a
CROSS JOIN names b
WHERE a.id < b.id
"""
con.execute(sql)

# FORMAT 1: Inefficient version
sql = """
WITH similarity_calc AS (
    SELECT
        *,
        CASE
            WHEN levenshtein(name_l, name_r) < 2 THEN 0
            WHEN levenshtein(name_l, name_r) < 4 THEN 1
            WHEN levenshtein(name_l, name_r) < 6 THEN 2
            ELSE 3
        END as similarity_bin
    FROM joined_names
)
SELECT * FROM similarity_calc
ORDER BY similarity_bin
"""
con.sql(sql)

# FORMAT 2:  Efficient version
sql = """
WITH base_calc AS (
    SELECT
        *,
        levenshtein(name_l, name_r) as distance
    FROM joined_names
),
similarity_calc AS (
    SELECT
        *,
        CASE
            WHEN distance < 2 THEN 0
            WHEN distance < 4 THEN 1
            WHEN distance < 6 THEN 2
            ELSE 3
        END as similarity_bin
    FROM base_calc
)
SELECT * FROM similarity_calc
ORDER BY distance
"""
con.sql(sql)


# Suggested solution

import sqlglot
from sqlglot import exp

# Inefficient SQL query (Format 1)
sql_format_1 = """
WITH similarity_calc AS (
    SELECT
        *,
        CASE
            WHEN levenshtein(name_l, name_r) < 2 THEN 0
            WHEN levenshtein(name_l, name_r) < 4 THEN 1
            WHEN levenshtein(name_l, name_r) < 6 THEN 2
            ELSE 3
        END as similarity_bin
    FROM joined_names
)
SELECT * FROM similarity_calc
ORDER BY similarity_bin
"""

# 1. Parse the SQL query
expression = sqlglot.parse_one(sql_format_1)

# 2. Define a transformation that rewrites the query into Format 2
def my_custom_transformation(expression):
    """
    This transformation only works for this specific example: it expects a CTE
    named similarity_calc, and it hard-codes the levenshtein(name_l, name_r)
    call and the thresholds instead of reading them from the CASE expression.
    """

    # Find the similarity_calc CTE
    similarity_calc_cte = expression.find(exp.CTE)
    if not similarity_calc_cte or similarity_calc_cte.alias != "similarity_calc":
        return expression

    # Rewrite similarity_calc so that the CASE reads the precomputed distance
    similarity_calc_cte.set(
        "this",
        exp.select(
            "*",
            exp.case()
            .when(exp.column("distance") < 2, 0)
            .when(exp.column("distance") < 4, 1)
            .when(exp.column("distance") < 6, 2)
            .else_(3)
            .as_("similarity_bin"),
        ).from_("base_calc"),
    )

    # New CTE that computes levenshtein(name_l, name_r) once per row
    base_calc_cte = exp.CTE(
        this=exp.select(
            "*",
            exp.func("levenshtein", exp.column("name_l"), exp.column("name_r")).as_("distance"),
        ).from_("joined_names"),
        alias=exp.TableAlias(this=exp.to_identifier("base_calc")),
    )

    # base_calc must come before similarity_calc in the WITH clause
    with_expression = similarity_calc_cte.parent
    with_expression.set("expressions", [base_calc_cte, *with_expression.expressions])

    return expression

# 3. Apply the transformation to the root of the tree. Expression.transform()
#    would call it on every node, so call it directly on a copy instead.
transformed_expression = my_custom_transformation(expression.copy())

# 4. sqlglot's optimize() is deliberately not applied: rules such as merging
#    CTEs could inline the distance expression back into the CASE.

# 5. Output the transformed SQL (Format 2)
sql_format_2 = transformed_expression.sql(pretty=True)
print(sql_format_2)

import sqlglot
from sqlglot import exp

sql = """
SELECT
    CASE
        WHEN levenshtein(name_l, name_r) < 2 THEN 0
        WHEN levenshtein(name_l, name_r) < 4 THEN 1
        WHEN levenshtein(name_l, name_r) < 6 THEN 2
        WHEN jaro(name_l, name_r) < 0.9 THEN 2
        WHEN jaro(name_l, name_r) < 0.8 THEN 2
        ELSE 3
    END as similarity_bin
FROM joined_names
"""


expression = sqlglot.parse_one(sql)

function_counts = {}
for case_expr in expression.find_all(exp.Case):
    for when_expr in case_expr.find_all(exp.If):
        for func in when_expr.find_all(exp.Func):
            func_sql = func.sql()
            function_counts[func_sql] = function_counts.get(func_sql, 0) + 1


repeated_functions = {func for func, count in function_counts.items() if count > 1}


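def transform(node):
    # Replace each repeated function call with a column reference named after
    # it, i.e. the column a precomputed value would be stored under
    if isinstance(node, exp.Func) and node.sql() in repeated_functions:
        cleaned = "".join(c if c.isalnum() else "_" for c in node.sql().lower())
        return exp.column(cleaned)
    return node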


transformed = expression.transform(transform)
print(transformed.sql())

But I came to the conclusion that this is not the right approach. Instead, we should detect common expressions at the point where we generate the CASE statement, because the SQL is much simpler at that stage.
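A minimal sketch of that idea, assuming comparison levels are still available as separate (expression, operator, threshold, value) parts when the SQL is generated. The `Level` class, `build_case_with_precompute`, and the `precomputed_N` aliases are hypothetical names, not Splink internals. Any expression used by more than one level is moved into a list of precomputed select items, and the CASE refers to its alias instead.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class Level:
    """Hypothetical comparison level: WHEN <expr> <op> <threshold> THEN <value>."""
    expr: str
    op: str
    threshold: float
    value: int


def build_case_with_precompute(levels: list[Level], else_value: int) -> tuple[list[str], str]:
    """Return (precompute select items, CASE expression) for a list of levels."""
    counts = Counter(level.expr for level in levels)

    # Only expressions used by more than one level are worth precomputing
    aliases: dict[str, str] = {}
    for expr, n in counts.items():
        if n > 1:
            aliases[expr] = f"precomputed_{len(aliases)}"

    precompute_items = [f"{expr} AS {alias}" for expr, alias in aliases.items()]

    # Levels whose expression was precomputed refer to the alias instead
    whens = [
        f"    WHEN {aliases.get(lvl.expr, lvl.expr)} {lvl.op} {lvl.threshold} THEN {lvl.value}"
        for lvl in levels
    ]
    case_sql = "\n".join(["CASE", *whens, f"    ELSE {else_value}", "END"])
    return precompute_items, case_sql


levels = [
    Level("levenshtein(name_l, name_r)", "<", 2, 0),
    Level("levenshtein(name_l, name_r)", "<", 4, 1),
    Level("jaro(name_l, name_r)", ">", 0.9, 2),
]
items, case_sql = build_case_with_precompute(levels, else_value=3)
print(items)  # ['levenshtein(name_l, name_r) AS precomputed_0']
print(case_sql)
```

In this sketch the precomputed items would go into an inner SELECT or CTE, like `precomputed_cosine` in the proposal, with the CASE evaluated in the outer query. Because it works on the expression strings the generator already holds, no SQL parsing is needed.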
