From 0ca2015ead9d70b8d0ff70b2a87edd489b91bf7b Mon Sep 17 00:00:00 2001 From: klaus triendl Date: Sat, 18 Nov 2023 23:35:20 +0200 Subject: [PATCH] Possibility of passing arguments when constructing user-defined function --- dev/storage_base.h | 215 ++++++++++++++++++------------- include/sqlite_orm/sqlite_orm.h | 215 ++++++++++++++++++------------- tests/user_defined_functions.cpp | 53 +++++++- 3 files changed, 294 insertions(+), 189 deletions(-) diff --git a/dev/storage_base.h b/dev/storage_base.h index 6785a0fa0..55bdd31e2 100644 --- a/dev/storage_base.h +++ b/dev/storage_base.h @@ -236,8 +236,13 @@ namespace sqlite_orm { } /** - * Create a user-defined scalar function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined scalar SQL function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_scalar_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * T - function class. T must have operator() overload and static name function like this: * ``` @@ -255,59 +260,43 @@ namespace sqlite_orm { * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_scalar_function() { + template + void create_scalar_function(Args&&... args) { static_assert(is_scalar_udf_v, "F must be a scalar function"); - std::stringstream ss; - ss << F::name() << std::flush; - auto name = ss.str(); - using args_tuple = typename callable_arguments::args_tuple; - using return_type = typename callable_arguments::return_type; - constexpr auto argsCount = std::is_same>::value - ? -1 - : int(std::tuple_size::value); - this->scalarFunctions.push_back(make_udf_proxy( - std::move(name), - argsCount, - /* constructAt = */ - [](void* location) { - std::allocator allocator; - using traits = std::allocator_traits; - traits::construct(allocator, (F*)location); - }, - /* destroy = */ - obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), - /* call = */ - [](void* udfHandle, sqlite3_context* context, int argsCount, sqlite3_value** values) { - F& udf = *static_cast(udfHandle); - args_tuple argsTuple = tuple_from_values{}(values, argsCount); - auto result = polyfill::apply(udf, std::move(argsTuple)); - statement_binder().result(context, result); - })); - - if(this->connection->retain_count() > 0) { - sqlite3* db = this->connection->get(); - try_to_create_scalar_function(db, *this->scalarFunctions.back()); - } + this->create_scalar_function_impl(/* constructAt */ [args...](void* location) { + std::allocator allocator; + using traits = std::allocator_traits; + traits::construct(allocator, (F*)location, args...); + }); } #ifdef SQLITE_ORM_WITH_CPP20_ALIASES /** - * Create a user-defined scalar function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined scalar function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_scalar_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_scalar_function() { - return this->create_scalar_function>(); + template + void create_scalar_function(Args&&... args) { + return this->create_scalar_function>(std::forward(args)...); } #endif /** - * Create a user-defined aggregate function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined aggregate SQL function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_aggregate_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * T - function class. T must have step member function, fin member function and static name function like this: * ``` @@ -332,66 +321,33 @@ namespace sqlite_orm { * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_aggregate_function() { + template + void create_aggregate_function(Args&&... args) { static_assert(is_aggregate_udf_v, "F must be an aggregate function"); - std::stringstream ss; - ss << F::name() << std::flush; - auto name = ss.str(); - using args_tuple = typename callable_arguments::args_tuple; - using return_type = typename callable_arguments::return_type; - constexpr auto argsCount = std::is_same>::value - ? -1 - : int(std::tuple_size::value); - this->aggregateFunctions.push_back(make_udf_proxy( - std::move(name), - argsCount, - /* constructAt = */ - [](void* location) { - std::allocator allocator; - using traits = std::allocator_traits; - traits::construct(allocator, (F*)location); - }, - /* destroy = */ - obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), - /* step = */ - [](void* udfHandle, sqlite3_context*, int argsCount, sqlite3_value** values) { - F& udf = *static_cast(udfHandle); - args_tuple argsTuple = tuple_from_values{}(values, argsCount); -#if __cpp_lib_bind_front >= 201907L - std::apply(std::bind_front(&F::step, &udf), std::move(argsTuple)); -#else - polyfill::apply( - [&udf](auto&&... args) { - udf.step(std::forward(args)...); - }, - std::move(argsTuple)); -#endif - }, - /* finalCall = */ - [](void* udfHandle, sqlite3_context* context) { - F& udf = *static_cast(udfHandle); - auto result = udf.fin(); - statement_binder().result(context, result); - })); - - if(this->connection->retain_count() > 0) { - sqlite3* db = this->connection->get(); - try_to_create_aggregate_function(db, *this->aggregateFunctions.back()); - } + this->create_aggregate_function_impl(/* constructAt = */ + [args...](void* location) { + std::allocator allocator; + using traits = std::allocator_traits; + traits::construct(allocator, (F*)location, args...); + }); } #ifdef SQLITE_ORM_WITH_CPP20_ALIASES /** - * Create a user-defined aggregate function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined aggregate function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_aggregate_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_aggregate_function() { - return this->create_aggregate_function>(); + template + void create_aggregate_function(Args&&... args) { + return this->create_aggregate_function>(std::forward(args)...); } #endif @@ -692,6 +648,79 @@ namespace sqlite_orm { } } + template + void create_scalar_function_impl(std::function constructAt) { + std::stringstream ss; + ss << F::name() << std::flush; + auto name = ss.str(); + using args_tuple = typename callable_arguments::args_tuple; + using return_type = typename callable_arguments::return_type; + constexpr auto argsCount = std::is_same>::value + ? -1 + : int(std::tuple_size::value); + this->scalarFunctions.push_back(make_udf_proxy( + std::move(name), + argsCount, + std::move(constructAt), + /* destroy = */ + obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), + /* call = */ + [](void* udfHandle, sqlite3_context* context, int argsCount, sqlite3_value** values) { + F& udf = *static_cast(udfHandle); + args_tuple argsTuple = tuple_from_values{}(values, argsCount); + auto result = polyfill::apply(udf, std::move(argsTuple)); + statement_binder().result(context, result); + })); + + if(this->connection->retain_count() > 0) { + sqlite3* db = this->connection->get(); + try_to_create_scalar_function(db, *this->scalarFunctions.back()); + } + } + + template + void create_aggregate_function_impl(std::function constructAt) { + std::stringstream ss; + ss << F::name() << std::flush; + auto name = ss.str(); + using args_tuple = typename callable_arguments::args_tuple; + using return_type = typename callable_arguments::return_type; + constexpr auto argsCount = std::is_same>::value + ? -1 + : int(std::tuple_size::value); + this->aggregateFunctions.push_back(make_udf_proxy( + std::move(name), + argsCount, + std::move(constructAt), + /* destroy = */ + obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), + /* step = */ + [](void* udfHandle, sqlite3_context*, int argsCount, sqlite3_value** values) { + F& udf = *static_cast(udfHandle); + args_tuple argsTuple = tuple_from_values{}(values, argsCount); +#if __cpp_lib_bind_front >= 201907L + std::apply(std::bind_front(&F::step, &udf), std::move(argsTuple)); +#else + polyfill::apply( + [&udf](auto&&... args) { + udf.step(std::forward(args)...); + }, + std::move(argsTuple)); +#endif + }, + /* finalCall = */ + [](void* udfHandle, sqlite3_context* context) { + F& udf = *static_cast(udfHandle); + auto result = udf.fin(); + statement_binder().result(context, result); + })); + + if(this->connection->retain_count() > 0) { + sqlite3* db = this->connection->get(); + try_to_create_aggregate_function(db, *this->aggregateFunctions.back()); + } + } + void delete_function_impl(const std::string& name, std::vector>& functions) const { #if __cpp_lib_ranges >= 201911L diff --git a/include/sqlite_orm/sqlite_orm.h b/include/sqlite_orm/sqlite_orm.h index e13bbe52c..ab6233fe9 100644 --- a/include/sqlite_orm/sqlite_orm.h +++ b/include/sqlite_orm/sqlite_orm.h @@ -15199,8 +15199,13 @@ namespace sqlite_orm { } /** - * Create a user-defined scalar function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined scalar SQL function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_scalar_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * T - function class. T must have operator() overload and static name function like this: * ``` @@ -15218,59 +15223,43 @@ namespace sqlite_orm { * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_scalar_function() { + template + void create_scalar_function(Args&&... args) { static_assert(is_scalar_udf_v, "F must be a scalar function"); - std::stringstream ss; - ss << F::name() << std::flush; - auto name = ss.str(); - using args_tuple = typename callable_arguments::args_tuple; - using return_type = typename callable_arguments::return_type; - constexpr auto argsCount = std::is_same>::value - ? -1 - : int(std::tuple_size::value); - this->scalarFunctions.push_back(make_udf_proxy( - std::move(name), - argsCount, - /* constructAt = */ - [](void* location) { - std::allocator allocator; - using traits = std::allocator_traits; - traits::construct(allocator, (F*)location); - }, - /* destroy = */ - obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), - /* call = */ - [](void* udfHandle, sqlite3_context* context, int argsCount, sqlite3_value** values) { - F& udf = *static_cast(udfHandle); - args_tuple argsTuple = tuple_from_values{}(values, argsCount); - auto result = polyfill::apply(udf, std::move(argsTuple)); - statement_binder().result(context, result); - })); - - if(this->connection->retain_count() > 0) { - sqlite3* db = this->connection->get(); - try_to_create_scalar_function(db, *this->scalarFunctions.back()); - } + this->create_scalar_function_impl(/* constructAt */ [args...](void* location) { + std::allocator allocator; + using traits = std::allocator_traits; + traits::construct(allocator, (F*)location, args...); + }); } #ifdef SQLITE_ORM_WITH_CPP20_ALIASES /** - * Create a user-defined scalar function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined scalar function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_scalar_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_scalar_function() { - return this->create_scalar_function>(); + template + void create_scalar_function(Args&&... args) { + return this->create_scalar_function>(std::forward(args)...); } #endif /** - * Create a user-defined aggregate function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined aggregate SQL function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_aggregate_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * T - function class. T must have step member function, fin member function and static name function like this: * ``` @@ -15295,66 +15284,33 @@ namespace sqlite_orm { * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_aggregate_function() { + template + void create_aggregate_function(Args&&... args) { static_assert(is_aggregate_udf_v, "F must be an aggregate function"); - std::stringstream ss; - ss << F::name() << std::flush; - auto name = ss.str(); - using args_tuple = typename callable_arguments::args_tuple; - using return_type = typename callable_arguments::return_type; - constexpr auto argsCount = std::is_same>::value - ? -1 - : int(std::tuple_size::value); - this->aggregateFunctions.push_back(make_udf_proxy( - std::move(name), - argsCount, - /* constructAt = */ - [](void* location) { - std::allocator allocator; - using traits = std::allocator_traits; - traits::construct(allocator, (F*)location); - }, - /* destroy = */ - obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), - /* step = */ - [](void* udfHandle, sqlite3_context*, int argsCount, sqlite3_value** values) { - F& udf = *static_cast(udfHandle); - args_tuple argsTuple = tuple_from_values{}(values, argsCount); -#if __cpp_lib_bind_front >= 201907L - std::apply(std::bind_front(&F::step, &udf), std::move(argsTuple)); -#else - polyfill::apply( - [&udf](auto&&... args) { - udf.step(std::forward(args)...); - }, - std::move(argsTuple)); -#endif - }, - /* finalCall = */ - [](void* udfHandle, sqlite3_context* context) { - F& udf = *static_cast(udfHandle); - auto result = udf.fin(); - statement_binder().result(context, result); - })); - - if(this->connection->retain_count() > 0) { - sqlite3* db = this->connection->get(); - try_to_create_aggregate_function(db, *this->aggregateFunctions.back()); - } + this->create_aggregate_function_impl(/* constructAt = */ + [args...](void* location) { + std::allocator allocator; + using traits = std::allocator_traits; + traits::construct(allocator, (F*)location, args...); + }); } #ifdef SQLITE_ORM_WITH_CPP20_ALIASES /** - * Create a user-defined aggregate function. - * Can be called at any time no matter whether the database connection is opened or no. + * Create an application-defined aggregate function. + * Can be called at any time no matter whether the database connection is opened or not. + * + * Note: `create_aggregate_function()` merely creates a closure to generate an instance of the scalar function object, + * together with a copy of the passed initialization arguments. + * An instance of the function object is repeatedly recreated for each result row, + * ensuring that the calculations always start with freshly initialized values. * * Attention: Currently, a function's name must not contain white-space characters, because it doesn't get quoted. */ - template - void create_aggregate_function() { - return this->create_aggregate_function>(); + template + void create_aggregate_function(Args&&... args) { + return this->create_aggregate_function>(std::forward(args)...); } #endif @@ -15655,6 +15611,79 @@ namespace sqlite_orm { } } + template + void create_scalar_function_impl(std::function constructAt) { + std::stringstream ss; + ss << F::name() << std::flush; + auto name = ss.str(); + using args_tuple = typename callable_arguments::args_tuple; + using return_type = typename callable_arguments::return_type; + constexpr auto argsCount = std::is_same>::value + ? -1 + : int(std::tuple_size::value); + this->scalarFunctions.push_back(make_udf_proxy( + std::move(name), + argsCount, + std::move(constructAt), + /* destroy = */ + obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), + /* call = */ + [](void* udfHandle, sqlite3_context* context, int argsCount, sqlite3_value** values) { + F& udf = *static_cast(udfHandle); + args_tuple argsTuple = tuple_from_values{}(values, argsCount); + auto result = polyfill::apply(udf, std::move(argsTuple)); + statement_binder().result(context, result); + })); + + if(this->connection->retain_count() > 0) { + sqlite3* db = this->connection->get(); + try_to_create_scalar_function(db, *this->scalarFunctions.back()); + } + } + + template + void create_aggregate_function_impl(std::function constructAt) { + std::stringstream ss; + ss << F::name() << std::flush; + auto name = ss.str(); + using args_tuple = typename callable_arguments::args_tuple; + using return_type = typename callable_arguments::return_type; + constexpr auto argsCount = std::is_same>::value + ? -1 + : int(std::tuple_size::value); + this->aggregateFunctions.push_back(make_udf_proxy( + std::move(name), + argsCount, + std::move(constructAt), + /* destroy = */ + obtain_xdestroy_for(udf_proxy::destruct_only_deleter{}), + /* step = */ + [](void* udfHandle, sqlite3_context*, int argsCount, sqlite3_value** values) { + F& udf = *static_cast(udfHandle); + args_tuple argsTuple = tuple_from_values{}(values, argsCount); +#if __cpp_lib_bind_front >= 201907L + std::apply(std::bind_front(&F::step, &udf), std::move(argsTuple)); +#else + polyfill::apply( + [&udf](auto&&... args) { + udf.step(std::forward(args)...); + }, + std::move(argsTuple)); +#endif + }, + /* finalCall = */ + [](void* udfHandle, sqlite3_context* context) { + F& udf = *static_cast(udfHandle); + auto result = udf.fin(); + statement_binder().result(context, result); + })); + + if(this->connection->retain_count() > 0) { + sqlite3* db = this->connection->get(); + try_to_create_aggregate_function(db, *this->aggregateFunctions.back()); + } + } + void delete_function_impl(const std::string& name, std::vector>& functions) const { #if __cpp_lib_ranges >= 201911L diff --git a/tests/user_defined_functions.cpp b/tests/user_defined_functions.cpp index 0ac76a3f4..1b9e23f8a 100644 --- a/tests/user_defined_functions.cpp +++ b/tests/user_defined_functions.cpp @@ -177,7 +177,7 @@ struct alignas(2 * __STDCPP_DEFAULT_NEW_ALIGNMENT__) OverAlignedScalarFunction { } static const char* name() { - return "OVERALIGNED"; + return "OVERALIGNED1"; } }; @@ -187,16 +187,47 @@ struct alignas(2 * __STDCPP_DEFAULT_NEW_ALIGNMENT__) OverAlignedAggregateFunctio void step(double arg) { sum += arg; } - int fin() const { + double fin() const { return sum; } static const char* name() { - return "OVERALIGNED"; + return "OVERALIGNED2"; } }; #endif +struct NonDefaultCtorScalarFunction { + const int multiplier; + + NonDefaultCtorScalarFunction(int multiplier) : multiplier{multiplier} {} + + int operator()(int arg) const { + return multiplier * arg; + } + + static const char* name() { + return "CTORTEST1"; + } +}; + +struct NonDefaultCtorAggregateFunction { + int sum; + + NonDefaultCtorAggregateFunction(int initialValue) : sum{initialValue} {} + + void step(int arg) { + sum += arg; + } + int fin() const { + return sum; + } + + static const char* name() { + return "CTORTEST2"; + } +}; + TEST_CASE("custom functions") { using Catch::Matchers::ContainsSubstring; @@ -348,4 +379,20 @@ TEST_CASE("custom functions") { REQUIRE_NOTHROW(storage.delete_aggregate_function()); } #endif + + storage.create_scalar_function(42); + { + auto rows = storage.select(func(1)); + decltype(rows) expected{42}; + REQUIRE(rows == expected); + } + storage.delete_scalar_function(); + + storage.create_aggregate_function(42); + { + auto rows = storage.select(func(1)); + decltype(rows) expected{43}; + REQUIRE(rows == expected); + } + storage.delete_aggregate_function(); }