Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate to NULLIF #35327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions EFCore.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ The .NET Foundation licenses this file to you under the MIT license.
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery_0027s/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=transactionality/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalesce/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalescing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unconfigured/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unignore/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=fixup/@EntryIndexedValue">True</s:Boolean>
Expand Down
35 changes: 31 additions & 4 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,37 @@ public virtual SqlExpression Condition(SqlExpression test, SqlExpression ifTrue,
{
var typeMapping = ExpressionExtensions.InferTypeMapping(ifTrue, ifFalse);

return new SqlConditionalExpression(
ApplyTypeMapping(test, _boolTypeMapping),
ApplyTypeMapping(ifTrue, typeMapping),
ApplyTypeMapping(ifFalse, typeMapping));
test = ApplyTypeMapping(test, _boolTypeMapping);
ifTrue = ApplyTypeMapping(ifTrue, typeMapping);
ifFalse = ApplyTypeMapping(ifFalse, typeMapping);

// Simplify:
// a == b ? b : a -> a
// a != b ? a : b -> a
if (test is SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary)
{
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal ? (ifTrue, ifFalse) : (ifFalse, ifTrue);

// a == b ? b : a -> a
if (left.Equals(ifNotEqual) && right.Equals(ifEqual))
{
return left;
}

// b == a ? b : a -> a
if (right.Equals(ifNotEqual) && left.Equals(ifEqual))
{
return right;
}
}

return new SqlConditionalExpression(test, ifTrue, ifFalse);
}

/// <summary>
Expand Down
52 changes: 52 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,58 @@ public virtual SqlExpression Case(
elseResult = lastCase.ElseResult;
}

// Simplify:
// a == b ? b : a -> a
// a != b ? a : b -> a
// And lift:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
if (operand is null
&& typeMappedWhenClauses is
[
{
Test: SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary,
Result: var result
}
])
{
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal
? (result, elseResult ?? Constant(null, result.Type, result.TypeMapping))
: (elseResult ?? Constant(null, result.Type, result.TypeMapping), result);

if (left.Equals(ifNotEqual))
{
switch (ifEqual)
{
// a == b ? b : a -> a
case var _ when ifEqual.Equals(right):
return left;
// a == b ? null : a -> NULLIF(a, b)
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
}
}

if (right.Equals(ifNotEqual))
{
switch (ifEqual)
{
// b == a ? b : a -> a
case var _ when ifEqual.Equals(left):
return right;
// b == a ? null : a -> NULLIF(a, b)
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
}
}
}

return existingExpression is CaseExpression expr
&& operand == expr.Operand
&& typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,94 @@ public OperatorTranslationsCosmosTest(BasicTypesQueryCosmosFixture fixture, ITes
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

#region Conditional

public override Task Conditional_simplifiable_equality(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_simplifiable_equality(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (c["Int"] > 1)
""");
});

public override Task Conditional_simplifiable_inequality(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_simplifiable_inequality(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (c["Int"] > 1)
""");
});

public override Task Conditional_uncoalesce_with_equality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_equality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_equality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
""");
});

public override Task Conditional_uncoalesce_with_unequality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_unequality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
""");
});

public override Task Conditional_uncoalesce_with_inequality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Conditional_uncoalesce_with_inequality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
""");
});

#endregion Conditional

#region Bitwise

public override Task Bitwise_or(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ public virtual Task Where_equal_with_conditional(bool async)
ss => ss.Set<NullSemanticsEntity1>().Where(
e => (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)
: e.NullableStringC)
== e.NullableStringC).Select(e => e.Id));

[ConditionalTheory]
Expand All @@ -765,7 +765,7 @@ public virtual Task Where_not_equal_with_conditional(bool async)
e => e.NullableStringC
!= (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)).Select(e => e.Id));
: e.NullableStringC)).Select(e => e.Id));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,58 @@ namespace Microsoft.EntityFrameworkCore.Query.Translations;
public abstract class OperatorTranslationsTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
where TFixture : BasicTypesQueryFixtureBase, new()
{
// See also operators precedence tests in OperatorsQueryTestBase

#region Conditional

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_equality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == 9 ? 9 : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_inequality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != 8 ? x.Int : 8) > 1));

// In relational providers, x == a ? null : x ("un-coalescing conditional") is translated to SQL NULLIF

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_equality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_unequality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_uncoalesce_with_inequality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));

#endregion Conditional

#region Bitwise
#pragma warning disable CS0675 // Bitwise-or operator used on a sign-extended operand

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,7 @@ public override async Task Select_null_propagation_works_for_multiple_navigation

AssertSql(
"""
SELECT CASE
WHEN [c].[Name] IS NOT NULL THEN [c].[Name]
END
SELECT [c].[Name]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
LEFT JOIN [Tags] AS [t0] ON ([g].[Nickname] = [t0].[GearNickName] OR ([g].[Nickname] IS NULL AND [t0].[GearNickName] IS NULL)) AND ([g].[SquadId] = [t0].[GearSquadId] OR ([g].[SquadId] IS NULL AND [t0].[GearSquadId] IS NULL))
Expand Down Expand Up @@ -1981,10 +1979,7 @@ public override async Task Optional_navigation_type_compensation_works_with_pred
SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[IssueDate], [t].[Note]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
WHERE CASE
WHEN [g].[HasSoulPatch] = CAST(1 AS bit) THEN CAST(1 AS bit)
ELSE [g].[HasSoulPatch]
END = CAST(0 AS bit)
WHERE [g].[HasSoulPatch] = CAST(0 AS bit)
""");
}

Expand All @@ -1997,10 +1992,7 @@ public override async Task Optional_navigation_type_compensation_works_with_pred
SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[IssueDate], [t].[Note]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
WHERE CASE
WHEN [g].[HasSoulPatch] = CAST(0 AS bit) THEN CAST(0 AS bit)
ELSE [g].[HasSoulPatch]
END = CAST(0 AS bit)
WHERE [g].[HasSoulPatch] = CAST(0 AS bit)
""");
}

Expand Down Expand Up @@ -3057,9 +3049,7 @@ public override async Task Select_null_conditional_with_inheritance(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[CommanderName] IS NOT NULL THEN [f].[CommanderName]
END
SELECT [f].[CommanderName]
FROM [Factions] AS [f]
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ public NorthwindFunctionsQuerySqlServer160Test(Fixture160 fixture, ITestOutputHe
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());

public override async Task Client_evaluation_of_uncorrelated_method_call(bool async)
{
await base.Client_evaluation_of_uncorrelated_method_call(async);

AssertSql(
"""
SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[UnitPrice] < 7.0 AND 10 < [o].[ProductID]
""");
}

public override async Task Sum_over_round_works_correctly_in_projection(bool async)
{
await base.Sum_over_round_works_correctly_in_projection(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2207,10 +2207,10 @@ SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END = [e].[NullableStringC] OR (CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NULL AND [e].[NullableStringC] IS NULL)
""");
}
Expand All @@ -2225,13 +2225,13 @@ SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE ([e].[NullableStringC] <> CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END OR [e].[NullableStringC] IS NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NULL) AND ([e].[NullableStringC] IS NOT NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringB]
ELSE [e].[NullableStringC]
END IS NOT NULL)
""");
}
Expand Down
Loading