Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements getter functions for LFMCMC #49

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ S3method(get_n_viruses,epiworld_model)
S3method(get_name,epiworld_model)
S3method(get_ndays,epiworld_model)
S3method(get_param,epiworld_model)
S3method(get_params_mean,epiworld_lfmcmc)
S3method(get_reproductive_number,epiworld_model)
S3method(get_states,epiworld_model)
S3method(get_stats_mean,epiworld_lfmcmc)
S3method(get_today_total,epiworld_model)
S3method(get_transition_probability,epiworld_model)
S3method(get_transmissions,epiworld_diffnet)
Expand Down Expand Up @@ -150,9 +152,11 @@ export(get_name_virus)
export(get_ndays)
export(get_network)
export(get_param)
export(get_params_mean)
export(get_reproductive_number)
export(get_state)
export(get_states)
export(get_stats_mean)
export(get_today_total)
export(get_tool)
export(get_transition_probability)
Expand Down
26 changes: 26 additions & 0 deletions R/LFMCMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
#' set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness"))
#'
#' print(lfmcmc_model)
#'
#' get_stats_mean(lfmcmc_model)
#' get_params_mean(lfmcmc_model)
#'
#' @export
LFMCMC <- function(model) {
if (!inherits(model, "epiworld_model"))
Expand Down Expand Up @@ -205,6 +209,28 @@ set_stats_names.epiworld_lfmcmc <- function(lfmcmc, names) {
invisible(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @returns The param means for the given lfmcmc model
#' @export
get_params_mean <- function(lfmcmc) UseMethod("get_params_mean")

#' @export
get_params_mean.epiworld_lfmcmc <- function(lfmcmc) {
get_params_mean_cpp(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @returns The stats means for the given lfmcmc model
#' @export
get_stats_mean <- function(lfmcmc) UseMethod("get_stats_mean")

#' @export
get_stats_mean.epiworld_lfmcmc <- function(lfmcmc) {
get_stats_mean_cpp(lfmcmc)
}

#' @rdname LFMCMC
#' @param x LFMCMC model to print
#' @param ... Ignored
Expand Down
8 changes: 8 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ set_stats_names_cpp <- function(lfmcmc, names) {
.Call(`_epiworldR_set_stats_names_cpp`, lfmcmc, names)
}

get_params_mean_cpp <- function(lfmcmc) {
.Call(`_epiworldR_get_params_mean_cpp`, lfmcmc)
}

get_stats_mean_cpp <- function(lfmcmc) {
.Call(`_epiworldR_get_stats_mean_cpp`, lfmcmc)
}

print_lfmcmc_cpp <- function(lfmcmc) {
.Call(`_epiworldR_print_lfmcmc_cpp`, lfmcmc)
}
Expand Down
3 changes: 3 additions & 0 deletions inst/tinytest/test-lfmcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ expect_silent(set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness")

expect_stdout(print(lfmcmc_model))

expect_equal(get_stats_mean(lfmcmc_model), c(4.45, 2.6135, 992.4365))
expect_equal(get_params_mean(lfmcmc_model), c(11.58421, 18.96851), tolerance = 0.00001)

# Check LFMCMC using factory functions -----------------------------------------
expect_silent(use_proposal_norm_reflective(lfmcmc_model))
expect_silent(use_kernel_fun_gaussian(lfmcmc_model))
Expand Down
14 changes: 14 additions & 0 deletions man/LFMCMC.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epiworld-methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,20 @@ extern "C" SEXP _epiworldR_set_stats_names_cpp(SEXP lfmcmc, SEXP names) {
END_CPP11
}
// lfmcmc.cpp
cpp11::writable::doubles get_params_mean_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_get_params_mean_cpp(SEXP lfmcmc) {
BEGIN_CPP11
return cpp11::as_sexp(get_params_mean_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc)));
END_CPP11
}
// lfmcmc.cpp
cpp11::writable::doubles get_stats_mean_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_get_stats_mean_cpp(SEXP lfmcmc) {
BEGIN_CPP11
return cpp11::as_sexp(get_stats_mean_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc)));
END_CPP11
}
// lfmcmc.cpp
SEXP print_lfmcmc_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_print_lfmcmc_cpp(SEXP lfmcmc) {
BEGIN_CPP11
Expand Down Expand Up @@ -1068,9 +1082,11 @@ static const R_CallMethodDef CallEntries[] = {
{"_epiworldR_get_ndays_cpp", (DL_FUNC) &_epiworldR_get_ndays_cpp, 1},
{"_epiworldR_get_network_cpp", (DL_FUNC) &_epiworldR_get_network_cpp, 1},
{"_epiworldR_get_param_cpp", (DL_FUNC) &_epiworldR_get_param_cpp, 2},
{"_epiworldR_get_params_mean_cpp", (DL_FUNC) &_epiworldR_get_params_mean_cpp, 1},
{"_epiworldR_get_reproductive_number_cpp", (DL_FUNC) &_epiworldR_get_reproductive_number_cpp, 1},
{"_epiworldR_get_state_agent_cpp", (DL_FUNC) &_epiworldR_get_state_agent_cpp, 1},
{"_epiworldR_get_states_cpp", (DL_FUNC) &_epiworldR_get_states_cpp, 1},
{"_epiworldR_get_stats_mean_cpp", (DL_FUNC) &_epiworldR_get_stats_mean_cpp, 1},
{"_epiworldR_get_today_total_cpp", (DL_FUNC) &_epiworldR_get_today_total_cpp, 1},
{"_epiworldR_get_tool_model_cpp", (DL_FUNC) &_epiworldR_get_tool_model_cpp, 2},
{"_epiworldR_get_transition_probability_cpp", (DL_FUNC) &_epiworldR_get_transition_probability_cpp, 1},
Expand Down
23 changes: 20 additions & 3 deletions src/lfmcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "cpp11/external_pointer.hpp"
#include "cpp11/r_vector.hpp"
#include "cpp11/sexp.hpp"
#include "cpp11/doubles.hpp"
#include <iostream>

#include "epiworld-common.h"
Expand Down Expand Up @@ -145,12 +146,12 @@ SEXP set_summary_fun_cpp(
LFMCMC<TData_default>*
) -> void {

if (res.size() == 0u)
res.resize(dat.size());

auto dat_int = cpp11::integers(dat);
auto res_tmp = cpp11::integers(fun(dat_int));

if (res.size() == 0u)
res.resize(res_tmp.size());

std::copy(res_tmp.begin(), res_tmp.end(), res.begin());

return;
Expand Down Expand Up @@ -225,6 +226,22 @@ SEXP set_stats_names_cpp(
return lfmcmc;
}

[[cpp11::register]]
cpp11::writable::doubles get_params_mean_cpp(
SEXP lfmcmc
) {
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
return cpp11::doubles(lfmcmc_ptr->get_params_mean());
}

[[cpp11::register]]
cpp11::writable::doubles get_stats_mean_cpp(
SEXP lfmcmc
) {
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
return cpp11::doubles(lfmcmc_ptr->get_stats_mean());
}

[[cpp11::register]]
SEXP print_lfmcmc_cpp(
SEXP lfmcmc
Expand Down
Loading