Skip to content

Commit

Permalink
Generalize simplification as suggested
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 15, 2024
1 parent f4472db commit 79c2d24
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
45 changes: 28 additions & 17 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ public virtual SqlExpression Case(
}

// Simplify:
// a == null ? null : a -> a
// a != null ? a : null -> a
// a == b ? b : a -> a
// a != b ? a : b -> a
// And lift:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
Expand All @@ -838,28 +838,39 @@ public virtual SqlExpression Case(
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
Result: var result
}
]
&& binary.OperatorType switch
{
ExpressionType.Equal when result is SqlConstantExpression { Value: null } && elseResult is not null => elseResult,
ExpressionType.NotEqual when elseResult is null or SqlConstantExpression { Value: null } => result,
_ => null
} is SqlExpression conditionalResult)
])
{
var (left, right) = (binary.Left, binary.Right);

if (left.Equals(conditionalResult))
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasonining below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal
? (result, elseResult ?? Constant(null, result.Type, result.TypeMapping))
: (elseResult ?? Constant(null, result.Type, result.TypeMapping), result);

if (left.Equals(ifNotEqual))
{
return right is SqlConstantExpression { Value: null }
? left
: Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
switch (ifEqual)
{
// a == b ? b : a -> a
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
// a == b ? null : a -> NULLIF(a, b)
case var _ when ifEqual.Equals(right):
return left;
}
}

if (right.Equals(conditionalResult))
if (right.Equals(ifNotEqual))
{
return left is SqlConstantExpression { Value: null }
? right
: Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
switch (ifEqual)
{
// b == a ? b : a -> a
case SqlConstantExpression { Value: null }:
return Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
// b == a ? null : a -> NULLIF(a, b)
case var _ when ifEqual.Equals(left):
return right;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ public virtual Task Conditional_simplifiable_equality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == null ? null : x.Int) > 1));
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == 9 ? 9 : x.Int) > 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Conditional_simplifiable_inequality(bool async)
=> AssertQuery(
async,
// ReSharper disable once MergeConditionalExpression
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != null ? x.Int : null) > 1));
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != 8 ? x.Int : 8) > 1));

// In relational providers, x == a ? null : x ("un-coalescing conditional") is translated to SQL NULLIF

Expand Down

0 comments on commit 79c2d24

Please sign in to comment.