Skip to content

Commit

Permalink
Updated a few things concerning user-defined function creation
Browse files Browse the repository at this point in the history
  • Loading branch information
trueqbit committed Nov 14, 2023
1 parent cb4bcac commit 60fff04
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 34 deletions.
37 changes: 20 additions & 17 deletions dev/storage_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ namespace sqlite_orm {
*/
template<class F>
void create_scalar_function() {
static_assert(is_scalar_udf_v<F>, "F cannot be an aggregate function");
static_assert(is_scalar_udf_v<F>, "F must be a scalar function");

std::stringstream ss;
ss << F::name() << std::flush;
Expand All @@ -270,10 +270,10 @@ namespace sqlite_orm {
},
/* call = */
[](sqlite3_context* context, void* udfHandle, int argsCount, sqlite3_value** values) {
F& function = *static_cast<F*>(udfHandle);
F& udf = *static_cast<F*>(udfHandle);
args_tuple argsTuple;
values_to_tuple{}(values, argsTuple, argsCount);
auto result = call(function, std::move(argsTuple));
auto result = call(udf, std::move(argsTuple));
statement_binder<return_type>().result(context, result);
},
delete_function_callback<F>));
Expand Down Expand Up @@ -318,7 +318,7 @@ namespace sqlite_orm {
*/
template<class F>
void create_aggregate_function() {
static_assert(is_aggregate_udf_v<F>, "F cannot be a scalar function");
static_assert(is_aggregate_udf_v<F>, "F must be an aggregate function");

std::stringstream ss;
ss << F::name() << std::flush;
Expand All @@ -337,15 +337,15 @@ namespace sqlite_orm {
},
/* step = */
[](sqlite3_context*, void* udfHandle, int argsCount, sqlite3_value** values) {
F& function = *static_cast<F*>(udfHandle);
F& udf = *static_cast<F*>(udfHandle);
args_tuple argsTuple;
values_to_tuple{}(values, argsTuple, argsCount);
call(function, &F::step, std::move(argsTuple));
call(udf, &F::step, std::move(argsTuple));
},
/* finalCall = */
[](sqlite3_context* context, void* udfHandle) {
F& function = *static_cast<F*>(udfHandle);
auto result = function.fin();
F& udf = *static_cast<F*>(udfHandle);
auto result = udf.fin();
statement_binder<return_type>().result(context, result);
},
delete_function_callback<F>));
Expand Down Expand Up @@ -689,7 +689,7 @@ namespace sqlite_orm {
}
}

void try_to_create_function(sqlite3* db, scalar_udf_proxy& udfProxy) {
static void try_to_create_function(sqlite3* db, scalar_udf_proxy& udfProxy) {
int rc = sqlite3_create_function_v2(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -704,7 +704,7 @@ namespace sqlite_orm {
}
}

void try_to_create_function(sqlite3* db, aggregate_udf_proxy& udfProxy) {
static void try_to_create_function(sqlite3* db, aggregate_udf_proxy& udfProxy) {
int rc = sqlite3_create_function(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -723,8 +723,11 @@ namespace sqlite_orm {
auto* udfProxy = static_cast<aggregate_udf_proxy*>(sqlite3_user_data(context));
// allocate or fetch pointer handle to user-defined function
void* aggregateStateMem = sqlite3_aggregate_context(context, sizeof(void**));
void* udfHandle = *static_cast<void**>(aggregateStateMem);
void*& udfHandle = *static_cast<void**>(aggregateStateMem);
if(udfHandle == nullptr) {
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
udfHandle = udfProxy->create();
}
udfProxy->step(context, udfHandle, argsCount, values);
Expand All @@ -734,7 +737,7 @@ namespace sqlite_orm {
auto* udfProxy = static_cast<aggregate_udf_proxy*>(sqlite3_user_data(context));
// allocate or fetch pointer handle to user-defined function
void* aggregateStateMem = sqlite3_aggregate_context(context, sizeof(void**));
void* udfHandle = *static_cast<void**>(aggregateStateMem);
void*& udfHandle = *static_cast<void**>(aggregateStateMem);
// note: it is possible that the 'step' function was never called
if(udfHandle == nullptr) {
udfHandle = udfProxy->create();
Expand All @@ -745,17 +748,17 @@ namespace sqlite_orm {

static void scalar_function_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
auto udfProxy = static_cast<scalar_udf_proxy*>(sqlite3_user_data(context));
const std::unique_ptr<void, xdestroy_fn_t> udfHandleGuard(udfProxy->create(), udfProxy->destroy);
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
udfProxy->run(context, udfProxy, argsCount, values);
const std::unique_ptr<void, xdestroy_fn_t> udfHandle(udfProxy->create(), udfProxy->destroy);
udfProxy->run(context, udfHandle.get(), argsCount, values);
}

template<class F>
static void delete_function_callback(void* pointer) {
auto fPointer = static_cast<F*>(pointer);
delete fPointer;
static void delete_function_callback(void* udfHandle) {
F* udf = static_cast<F*>(udfHandle);
delete udf;
}

std::string current_time(sqlite3* db) {
Expand Down
37 changes: 20 additions & 17 deletions include/sqlite_orm/sqlite_orm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15068,7 +15068,7 @@ namespace sqlite_orm {
*/
template<class F>
void create_scalar_function() {
static_assert(is_scalar_udf_v<F>, "F cannot be an aggregate function");
static_assert(is_scalar_udf_v<F>, "F must be a scalar function");

std::stringstream ss;
ss << F::name() << std::flush;
Expand All @@ -15087,10 +15087,10 @@ namespace sqlite_orm {
},
/* call = */
[](sqlite3_context* context, void* udfHandle, int argsCount, sqlite3_value** values) {
F& function = *static_cast<F*>(udfHandle);
F& udf = *static_cast<F*>(udfHandle);
args_tuple argsTuple;
values_to_tuple{}(values, argsTuple, argsCount);
auto result = call(function, std::move(argsTuple));
auto result = call(udf, std::move(argsTuple));
statement_binder<return_type>().result(context, result);
},
delete_function_callback<F>));
Expand Down Expand Up @@ -15135,7 +15135,7 @@ namespace sqlite_orm {
*/
template<class F>
void create_aggregate_function() {
static_assert(is_aggregate_udf_v<F>, "F cannot be a scalar function");
static_assert(is_aggregate_udf_v<F>, "F must be an aggregate function");

std::stringstream ss;
ss << F::name() << std::flush;
Expand All @@ -15154,15 +15154,15 @@ namespace sqlite_orm {
},
/* step = */
[](sqlite3_context*, void* udfHandle, int argsCount, sqlite3_value** values) {
F& function = *static_cast<F*>(udfHandle);
F& udf = *static_cast<F*>(udfHandle);
args_tuple argsTuple;
values_to_tuple{}(values, argsTuple, argsCount);
call(function, &F::step, std::move(argsTuple));
call(udf, &F::step, std::move(argsTuple));
},
/* finalCall = */
[](sqlite3_context* context, void* udfHandle) {
F& function = *static_cast<F*>(udfHandle);
auto result = function.fin();
F& udf = *static_cast<F*>(udfHandle);
auto result = udf.fin();
statement_binder<return_type>().result(context, result);
},
delete_function_callback<F>));
Expand Down Expand Up @@ -15506,7 +15506,7 @@ namespace sqlite_orm {
}
}

void try_to_create_function(sqlite3* db, scalar_udf_proxy& udfProxy) {
static void try_to_create_function(sqlite3* db, scalar_udf_proxy& udfProxy) {
int rc = sqlite3_create_function_v2(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -15521,7 +15521,7 @@ namespace sqlite_orm {
}
}

void try_to_create_function(sqlite3* db, aggregate_udf_proxy& udfProxy) {
static void try_to_create_function(sqlite3* db, aggregate_udf_proxy& udfProxy) {
int rc = sqlite3_create_function(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -15540,8 +15540,11 @@ namespace sqlite_orm {
auto* udfProxy = static_cast<aggregate_udf_proxy*>(sqlite3_user_data(context));
// allocate or fetch pointer handle to user-defined function
void* aggregateStateMem = sqlite3_aggregate_context(context, sizeof(void**));
void* udfHandle = *static_cast<void**>(aggregateStateMem);
void*& udfHandle = *static_cast<void**>(aggregateStateMem);
if(udfHandle == nullptr) {
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
udfHandle = udfProxy->create();
}
udfProxy->step(context, udfHandle, argsCount, values);
Expand All @@ -15551,7 +15554,7 @@ namespace sqlite_orm {
auto* udfProxy = static_cast<aggregate_udf_proxy*>(sqlite3_user_data(context));
// allocate or fetch pointer handle to user-defined function
void* aggregateStateMem = sqlite3_aggregate_context(context, sizeof(void**));
void* udfHandle = *static_cast<void**>(aggregateStateMem);
void*& udfHandle = *static_cast<void**>(aggregateStateMem);
// note: it is possible that the 'step' function was never called
if(udfHandle == nullptr) {
udfHandle = udfProxy->create();
Expand All @@ -15562,17 +15565,17 @@ namespace sqlite_orm {

static void scalar_function_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
auto udfProxy = static_cast<scalar_udf_proxy*>(sqlite3_user_data(context));
const std::unique_ptr<void, xdestroy_fn_t> udfHandleGuard(udfProxy->create(), udfProxy->destroy);
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
udfProxy->run(context, udfProxy, argsCount, values);
const std::unique_ptr<void, xdestroy_fn_t> udfHandle(udfProxy->create(), udfProxy->destroy);
udfProxy->run(context, udfHandle.get(), argsCount, values);
}

template<class F>
static void delete_function_callback(void* pointer) {
auto fPointer = static_cast<F*>(pointer);
delete fPointer;
static void delete_function_callback(void* udfHandle) {
F* udf = static_cast<F*>(udfHandle);
delete udf;
}

std::string current_time(sqlite3* db) {
Expand Down

0 comments on commit 60fff04

Please sign in to comment.