From 8a0fca5b74c32411c6e5f2255b290362cd4ba51b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Sat, 21 Sep 2024 21:38:50 +0200 Subject: [PATCH] Support arbitrary enumerables in NpgsqlArrayConverter (#3290) Closes #3286 --- .../Mapping/NpgsqlArrayTypeMapping.cs | 4 +- .../ValueConversion/NpgsqlArrayConverter.cs | 214 +++++++++++++----- .../PrimitiveCollectionsQueryNpgsqlTest.cs | 20 ++ 3 files changed, 176 insertions(+), 62 deletions(-) diff --git a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs index 6b70853b3..6672f0520 100644 --- a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs +++ b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs @@ -222,8 +222,10 @@ public override DbParameter CreateParameter( // In queries which compose non-server-correlated LINQ operators over an array parameter (e.g. Where(b => ids.Skip(1)...) we // get an enumerable parameter value that isn't an array/list - but those aren't supported at the Npgsql ADO level. // Detect this here and evaluate the enumerable to get a fully materialized List. + // Note that when we have a value converter (e.g. for HashSet), we don't want to convert it to a List, since the value converter + // expects the original type. 
// TODO: Make Npgsql support IList<> instead of only arrays and List<> - if (value is not null && !value.GetType().IsArrayOrGenericList()) + if (value is not null && Converter is null && !value.GetType().IsArrayOrGenericList()) { switch (value) { diff --git a/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs b/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs index f04221b5a..982d58eaa 100644 --- a/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs +++ b/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs @@ -105,93 +105,185 @@ private static Expression> ArrayConversionExpression(); - var variables = new List(4) - { - output, - lengthVariable, - }; + var variables = new List { output, lengthVariable }; Expression getInputLength; - Func indexer; + Func? indexer; - if (typeof(TInput).IsArray) - { - getInputLength = ArrayLength(input); - indexer = i => ArrayAccess(input, i); - } - else if (typeof(TInput).IsGenericType - && typeof(TInput).GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>))) - { - getInputLength = Property( - input, - typeof(TInput).GetProperty("Count") - // If TInput is an interface (IList), its Count property needs to be found on ICollection - ?? typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!); - indexer = i => Property(input, input.Type.FindIndexerProperty()!, i); - } - else + // The conversion is going to depend on what kind of input we have: array, list, collection, or arbitrary IEnumerable. + // For array/list we can get the length and index inside, so we can do an efficient for loop. + // For other ICollections (e.g. HashSet) we can get the length (and so pre-allocate the output), but we can't index; so we + // get an enumerator and use that. 
+ // For arbitrary IEnumerable, we can't get the length so we can't preallocate output arrays; so we have to call ToList() on it and then + // process that (note that we could avoid that when the output is a List rather than an array). + var inputInterfaces = input.Type.GetInterfaces(); + switch (input.Type) { - // Input collection isn't typed as an ICollection; it can be *typed* as an IEnumerable, but we only support concrete - // instances being ICollection. Emit code that casts the type at runtime. - var iListType = typeof(IList<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]); + // Input is typed as an array - we can get its length and index into it + case { IsArray: true }: + getInputLength = ArrayLength(input); + indexer = i => ArrayAccess(input, i); + break; + + // Input is typed as an IList - we can get its length and index into it + case { IsGenericType: true } when inputInterfaces.Append(input.Type) + .Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>)): + { + getInputLength = Property( + input, + input.Type.GetProperty("Count") + // If TInput is an interface (IList), its Count property needs to be found on ICollection + ?? 
typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!); + indexer = i => Property(input, input.Type.FindIndexerProperty()!, i); + break; + } - var convertedInput = Variable(iListType, "convertedInput"); - variables.Add(convertedInput); + // Input is typed as an ICollection - we can get its length, but we can't index into it + case { IsGenericType: true } when inputInterfaces.Append(input.Type) + .Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(ICollection<>)): + { + getInputLength = Property( + input, typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!); + indexer = null; + break; + } - expressions.Add(Assign(convertedInput, Convert(input, convertedInput.Type))); + // Input is typed as an IEnumerable - we can't get its length, and we can't index into it. + // All we can do is call ToList() on it and then process that. + case { IsGenericType: true } when inputInterfaces.Append(input.Type) + .Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)): + { + // TODO: In theory, we could add runtime checks for array/list/collection, downcast for those cases and include + // the logic from the other switch cases here. + convertedInput = Variable(typeof(List<>).MakeGenericType(inputElementType), "convertedInput"); + variables.Add(convertedInput); + expressions.Add( + Assign( + convertedInput, + Call(typeof(Enumerable).GetMethod(nameof(Enumerable.ToList))!.MakeGenericMethod(inputElementType), input))); + getInputLength = Property(convertedInput, convertedInput.Type.GetProperty("Count")!); + indexer = i => Property(convertedInput, convertedInput.Type.FindIndexerProperty()!, i); + break; + } - // TODO: Check and properly throw for non-IList, e.g. 
set - getInputLength = Property( - convertedInput, typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!); - indexer = i => Property(convertedInput, iListType.FindIndexerProperty()!, i); + default: + throw new NotSupportedException($"Array value converter input type must be an IEnumerable, but is {typeof(TInput)}"); } expressions.AddRange( [ // Get the length of the input array or list - // var length = input.Length; - Assign(lengthVariable, getInputLength), - - // Allocate an output array or list - // var result = new int[length]; - Assign( - output, typeof(TConcreteOutput).IsArray - ? NewArrayBounds(outputElementType, lengthVariable) - : typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength - ? New(ctorWithLength, lengthVariable) - : New(typeof(TConcreteOutput).GetConstructor([])!)), - - // Loop over the elements, applying the element converter on them one by one - // for (var i = 0; i < length; i++) - // { - // result[i] = input[i]; - // } + // var length = input.Length; + Assign(lengthVariable, getInputLength), + + // Allocate an output array or list + // var result = new int[length]; + Assign( + output, typeof(TConcreteOutput).IsArray + ? NewArrayBounds(outputElementType, lengthVariable) + : typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength + ? New(ctorWithLength, lengthVariable) + : New(typeof(TConcreteOutput).GetConstructor([])!)) + ]); + + if (indexer is not null) + { + // Good case: the input is an array or list, so we can index into it. Generate code for an efficient for loop, which applies + // the element converter on each element. 
+ // for (var i = 0; i < length; i++) + // { + // result[i] = input[i]; + // } + var counter = Parameter(typeof(int), "i"); + + expressions.Add( ForLoop( - loopVar: loopVariable, + loopVar: counter, initValue: Constant(0), - condition: LessThan(loopVariable, lengthVariable), - increment: AddAssign(loopVariable, Constant(1)), + condition: LessThan(counter, lengthVariable), + increment: AddAssign(counter, Constant(1)), loopContent: typeof(TConcreteOutput).IsArray ? Assign( - ArrayAccess(output, loopVariable), + ArrayAccess(output, counter), elementConversionExpression is null - ? indexer(loopVariable) - : Invoke(elementConversionExpression, indexer(loopVariable))) + ? indexer(counter) + : Invoke(elementConversionExpression, indexer(counter))) : Call( output, typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!, elementConversionExpression is null - ? indexer(loopVariable) - : Invoke(elementConversionExpression, indexer(loopVariable)))), - output - ]); + ? indexer(counter) + : Invoke(elementConversionExpression, indexer(counter))))); + } + else + { + // Bad case: the input is not an array or list, but is a collection (e.g. HashSet), so we can't index into it. + // Generate code for a less efficient enumerator-based iteration. 
+ // enumerator = input.GetEnumerator(); + // counter = 0; + // while (enumerator.MoveNext()) + // { + // output[counter] = enumerator.Current; + // counter++; + // } + var enumerableType = typeof(IEnumerable<>).MakeGenericType(inputElementType); + var enumeratorType = typeof(IEnumerator<>).MakeGenericType(inputElementType); + + var enumeratorVariable = Variable(enumeratorType, "enumerator"); + var counterVariable = Variable(typeof(int), "variable"); + variables.AddRange([enumeratorVariable, counterVariable]); + + expressions.AddRange( + [ + // enumerator = input.GetEnumerator(); + Assign(enumeratorVariable, Call(input, enumerableType.GetMethod(nameof(IEnumerable.GetEnumerator))!)), + + // counter = 0; + Assign(counterVariable, Constant(0)) + ]); + + var breakLabel = Label("LoopBreak"); + + var loop = + Loop( + IfThenElse( + Equal(Call(enumeratorVariable, typeof(IEnumerator).GetMethod(nameof(IEnumerator.MoveNext))!), Constant(true)), + Block( + typeof(TConcreteOutput).IsArray + // output[counter] = enumerator.Current; + ? Assign( + ArrayAccess(output, counterVariable), + elementConversionExpression is null + ? Property(enumeratorVariable, "Current") + : Invoke(elementConversionExpression, Property(enumeratorVariable, "Current"))) + // output.Add(enumerator.Current); + : Call( + output, + typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!, + elementConversionExpression is null + ? 
Property(enumeratorVariable, "Current") + : Invoke(elementConversionExpression, Property(enumeratorVariable, "Current"))), + + // counter++; + AddAssign(counterVariable, Constant(1))), + Break(breakLabel)), + breakLabel); + + expressions.Add( + TryFinally( + loop, + Call(enumeratorVariable, typeof(IDisposable).GetMethod(nameof(IDisposable.Dispose))!))); + } + + // return output; + expressions.Add(output); return Lambda>( // First, check if the given array value is null and return null immediately if so diff --git a/test/EFCore.PG.FunctionalTests/Query/PrimitiveCollectionsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/PrimitiveCollectionsQueryNpgsqlTest.cs index 7802ed3e7..b9a92210d 100644 --- a/test/EFCore.PG.FunctionalTests/Query/PrimitiveCollectionsQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/PrimitiveCollectionsQueryNpgsqlTest.cs @@ -508,6 +508,26 @@ WHERE NOT (p."Int" = ANY (@__ints_0) AND p."Int" = ANY (@__ints_0) IS NOT NULL) """); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Parameter_collection_HashSet_with_value_converter_Contains(bool async) + { + HashSet enums = [MyEnum.Value1, MyEnum.Value4]; + + await AssertQuery( + async, + ss => ss.Set().Where(c => enums.Contains(c.Enum))); + + AssertSql( + """ +@__enums_0={ '0', '3' } (DbType = Object) + +SELECT p."Id", p."Bool", p."Bools", p."DateTime", p."DateTimes", p."Enum", p."Enums", p."Int", p."Ints", p."NullableInt", p."NullableInts", p."NullableString", p."NullableStrings", p."String", p."Strings" +FROM "PrimitiveCollectionsEntity" AS p +WHERE p."Enum" = ANY (@__enums_0) +"""); + } + public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async) { await base.Parameter_collection_of_ints_Contains_nullable_int(async);