Skip to content

Commit

Permalink
Translate to NULLIF
Browse files Browse the repository at this point in the history
Closes #31682
  • Loading branch information
roji committed Dec 18, 2024
1 parent c53bbac commit 03fd843
Show file tree
Hide file tree
Showing 15 changed files with 414 additions and 73 deletions.
2 changes: 2 additions & 0 deletions EFCore.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ The .NET Foundation licenses this file to you under the MIT license.
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery_0027s/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=transactionality/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalesce/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalescing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unconfigured/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unignore/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=fixup/@EntryIndexedValue">True</s:Boolean>
Expand Down
35 changes: 31 additions & 4 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,37 @@ public virtual SqlExpression Condition(SqlExpression test, SqlExpression ifTrue,
{
var typeMapping = ExpressionExtensions.InferTypeMapping(ifTrue, ifFalse);

return new SqlConditionalExpression(
ApplyTypeMapping(test, _boolTypeMapping),
ApplyTypeMapping(ifTrue, typeMapping),
ApplyTypeMapping(ifFalse, typeMapping));
test = ApplyTypeMapping(test, _boolTypeMapping);
ifTrue = ApplyTypeMapping(ifTrue, typeMapping);
ifFalse = ApplyTypeMapping(ifFalse, typeMapping);

// Simplify:
// a == b ? b : a -> a
// a != b ? a : b -> a
if (test is SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary)
{
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal ? (ifTrue, ifFalse) : (ifFalse, ifTrue);

// a == b ? b : a -> a
if (left.Equals(ifNotEqual) && right.Equals(ifEqual))
{
return left;
}

// b == a ? b : a -> a
if (right.Equals(ifNotEqual) && left.Equals(ifEqual))
{
return right;
}
}

return new SqlConditionalExpression(test, ifTrue, ifFalse);
}

/// <summary>
Expand Down
52 changes: 52 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,58 @@ public virtual SqlExpression Case(
elseResult = lastCase.ElseResult;
}

// Simplify:
// a == b ? b : a -> a
// a != b ? a : b -> a
// And lift:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
if (operand is null
&& typeMappedWhenClauses is
[
{
Test: SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary,
Result: var result
}
])
{
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal
? (result, elseResult ?? Constant(null, result.Type, result.TypeMapping))
: (elseResult ?? Constant(null, result.Type, result.TypeMapping), result);

if (left.Equals(ifNotEqual))
{
switch (ifEqual)
{
// a == b ? b : a -> a
case var _ when ifEqual.Equals(right):
return left;
// a == b ? null : a -> NULLIF(a, b)
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
}
}

if (right.Equals(ifNotEqual))
{
switch (ifEqual)
{
// b == a ? b : a -> a
case var _ when ifEqual.Equals(left):
return right;
// b == a ? null : a -> NULLIF(a, b)
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
}
}
}

return existingExpression is CaseExpression expr
&& operand == expr.Operand
&& typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,94 @@ public OperatorTranslationsCosmosTest(BasicTypesQueryCosmosFixture fixture, ITes
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

#region Conditional

public override Task Conditional_simplifiable_equality(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_simplifiable_equality(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (c["Int"] > 1)
""");
});

public override Task Conditional_simplifiable_inequality(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_simplifiable_inequality(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (c["Int"] > 1)
""");
});

public override Task Conditional_uncoalesce_with_equality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_equality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_unequality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_unequality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
""");
});

public override Task Conditional_uncoalesce_with_inequality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_inequality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
""");
});

#endregion Conditional

#region Bitwise

public override Task Bitwise_or(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ public virtual Task Where_equal_with_conditional(bool async)
ss => ss.Set<NullSemanticsEntity1>().Where(
e => (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)
: e.NullableStringC)
== e.NullableStringC).Select(e => e.Id));

[ConditionalTheory]
Expand All @@ -765,7 +765,7 @@ public virtual Task Where_not_equal_with_conditional(bool async)
e => e.NullableStringC
!= (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)).Select(e => e.Id));
: e.NullableStringC)).Select(e => e.Id));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,58 @@ namespace Microsoft.EntityFrameworkCore.Query.Translations;
public abstract class OperatorTranslationsTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
where TFixture : BasicTypesQueryFixtureBase, new()
{
// See also operators precedence tests in OperatorsQueryTestBase

#region Conditional

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_equality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == 9 ? 9 : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_inequality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != 8 ? x.Int : 8) > 1));

// In relational providers, x == a ? null : x ("un-coalescing conditional") is translated to SQL NULLIF

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_unequality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_inequality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));

#endregion Conditional

#region Bitwise
#pragma warning disable CS0675 // Bitwise-or operator used on a sign-extended operand

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,7 @@ public override async Task Select_null_propagation_works_for_multiple_navigation

AssertSql(
"""
SELECT CASE
WHEN [c].[Name] IS NOT NULL THEN [c].[Name]
END
SELECT [c].[Name]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
LEFT JOIN [Tags] AS [t0] ON ([g].[Nickname] = [t0].[GearNickName] OR ([g].[Nickname] IS NULL AND [t0].[GearNickName] IS NULL)) AND ([g].[SquadId] = [t0].[GearSquadId] OR ([g].[SquadId] IS NULL AND [t0].[GearSquadId] IS NULL))
Expand Down Expand Up @@ -1981,10 +1979,7 @@ public override async Task Optional_navigation_type_compensation_works_with_pred
SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[IssueDate], [t].[Note]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
WHERE CASE
WHEN [g].[HasSoulPatch] = CAST(1 AS bit) THEN CAST(1 AS bit)
ELSE [g].[HasSoulPatch]
END = CAST(0 AS bit)
WHERE [g].[HasSoulPatch] = CAST(0 AS bit)
""");
}

Expand All @@ -1997,10 +1992,7 @@ public override async Task Optional_navigation_type_compensation_works_with_pred
SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[IssueDate], [t].[Note]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
WHERE CASE
WHEN [g].[HasSoulPatch] = CAST(0 AS bit) THEN CAST(0 AS bit)
ELSE [g].[HasSoulPatch]
END = CAST(0 AS bit)
WHERE [g].[HasSoulPatch] = CAST(0 AS bit)
""");
}

Expand Down Expand Up @@ -3057,9 +3049,7 @@ public override async Task Select_null_conditional_with_inheritance(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[CommanderName] IS NOT NULL THEN [f].[CommanderName]
END
SELECT [f].[CommanderName]
FROM [Factions] AS [f]
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ public NorthwindFunctionsQuerySqlServer160Test(Fixture160 fixture, ITestOutputHe
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());

public override async Task Client_evaluation_of_uncorrelated_method_call(bool async)
{
await base.Client_evaluation_of_uncorrelated_method_call(async);

AssertSql(
"""
SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[UnitPrice] < 7.0 AND 10 < [o].[ProductID]
""");
}

public override async Task Sum_over_round_works_correctly_in_projection(bool async)
{
await base.Sum_over_round_works_correctly_in_projection(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2207,10 +2207,10 @@ SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END = [e].[NullableStringC] OR (CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NULL AND [e].[NullableStringC] IS NULL)
""");
}
Expand All @@ -2225,13 +2225,13 @@ SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE ([e].[NullableStringC] <> CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END OR [e].[NullableStringC] IS NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NULL) AND ([e].[NullableStringC] IS NOT NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NOT NULL)
""");
}
Expand Down
Loading

0 comments on commit 03fd843

Please sign in to comment.